annotations.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from ray.rllib.utils.deprecation import Deprecated
  2. from ray.util.annotations import _mark_annotated
  3. def override(cls):
  4. """Decorator for documenting method overrides.
  5. Args:
  6. cls: The superclass that provides the overridden method. If this
  7. cls does not actually have the method, an error is raised.
  8. Examples:
  9. >>> from ray.rllib.policy import Policy
  10. >>> class TorchPolicy(Policy): # doctest: +SKIP
  11. ... ...
  12. ... # Indicates that `TorchPolicy.loss()` overrides the parent
  13. ... # Policy class' own `loss method. Leads to an error if Policy
  14. ... # does not have a `loss` method.
  15. ... @override(Policy) # doctest: +SKIP
  16. ... def loss(self, model, action_dist, train_batch): # doctest: +SKIP
  17. ... ... # doctest: +SKIP
  18. """
  19. def check_override(method):
  20. if method.__name__ not in dir(cls):
  21. raise NameError("{} does not override any method of {}".format(method, cls))
  22. return method
  23. return check_override
  24. def PublicAPI(obj):
  25. """Decorator for documenting public APIs.
  26. Public APIs are classes and methods exposed to end users of RLlib. You
  27. can expect these APIs to remain stable across RLlib releases.
  28. Subclasses that inherit from a ``@PublicAPI`` base class can be
  29. assumed part of the RLlib public API as well (e.g., all Algorithm classes
  30. are in public API because Algorithm is ``@PublicAPI``).
  31. In addition, you can assume all algo configurations are part of their
  32. public API as well.
  33. Examples:
  34. >>> # Indicates that the `Algorithm` class is exposed to end users
  35. >>> # of RLlib and will remain stable across RLlib releases.
  36. >>> from ray import tune
  37. >>> @PublicAPI # doctest: +SKIP
  38. >>> class Algorithm(tune.Trainable): # doctest: +SKIP
  39. ... ... # doctest: +SKIP
  40. """
  41. _mark_annotated(obj)
  42. return obj
  43. def DeveloperAPI(obj):
  44. """Decorator for documenting developer APIs.
  45. Developer APIs are classes and methods explicitly exposed to developers
  46. for the purposes of building custom algorithms or advanced training
  47. strategies on top of RLlib internals. You can generally expect these APIs
  48. to be stable sans minor changes (but less stable than public APIs).
  49. Subclasses that inherit from a ``@DeveloperAPI`` base class can be
  50. assumed part of the RLlib developer API as well.
  51. Examples:
  52. >>> # Indicates that the `TorchPolicy` class is exposed to end users
  53. >>> # of RLlib and will remain (relatively) stable across RLlib
  54. >>> # releases.
  55. >>> from ray.rllib.policy import Policy
  56. >>> @DeveloperAPI # doctest: +SKIP
  57. ... class TorchPolicy(Policy): # doctest: +SKIP
  58. ... ... # doctest: +SKIP
  59. """
  60. _mark_annotated(obj)
  61. return obj
  62. def ExperimentalAPI(obj):
  63. """Decorator for documenting experimental APIs.
  64. Experimental APIs are classes and methods that are in development and may
  65. change at any time in their development process. You should not expect
  66. these APIs to be stable until their tag is changed to `DeveloperAPI` or
  67. `PublicAPI`.
  68. Subclasses that inherit from a ``@ExperimentalAPI`` base class can be
  69. assumed experimental as well.
  70. Examples:
  71. >>> from ray.rllib.policy import Policy
  72. >>> class TorchPolicy(Policy): # doctest: +SKIP
  73. ... ... # doctest: +SKIP
  74. ... # Indicates that the `TorchPolicy.loss` method is a new and
  75. ... # experimental API and may change frequently in future
  76. ... # releases.
  77. ... @ExperimentalAPI # doctest: +SKIP
  78. ... def loss(self, model, action_dist, train_batch): # doctest: +SKIP
  79. ... ... # doctest: +SKIP
  80. """
  81. _mark_annotated(obj)
  82. return obj
  83. def OverrideToImplementCustomLogic(obj):
  84. """Users should override this in their sub-classes to implement custom logic.
  85. Used in Algorithm and Policy to tag methods that need overriding, e.g.
  86. `Policy.loss()`.
  87. Examples:
  88. >>> from ray.rllib.policy.torch_policy import TorchPolicy
  89. >>> @overrides(TorchPolicy) # doctest: +SKIP
  90. ... @OverrideToImplementCustomLogic # doctest: +SKIP
  91. ... def loss(self, ...): # doctest: +SKIP
  92. ... # implement custom loss function here ...
  93. ... # ... w/o calling the corresponding `super().loss()` method.
  94. ... ... # doctest: +SKIP
  95. """
  96. obj.__is_overriden__ = False
  97. return obj
  98. def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
  99. """Users should override this in their sub-classes to implement custom logic.
  100. Thereby, it is recommended (but not required) to call the super-class'
  101. corresponding method.
  102. Used in Algorithm and Policy to tag methods that need overriding, but the
  103. super class' method should still be called, e.g.
  104. `Algorithm.setup()`.
  105. Examples:
  106. >>> from ray import tune
  107. >>> @overrides(tune.Trainable) # doctest: +SKIP
  108. ... @OverrideToImplementCustomLogic_CallToSuperRecommended # doctest: +SKIP
  109. ... def setup(self, config): # doctest: +SKIP
  110. ... # implement custom setup logic here ...
  111. ... super().setup(config) # doctest: +SKIP
  112. ... # ... or here (after having called super()'s setup method.
  113. """
  114. obj.__is_overriden__ = False
  115. return obj
  116. def is_overridden(obj):
  117. """Check whether a function has been overridden.
  118. Note, this only works for API calls decorated with OverrideToImplementCustomLogic
  119. or OverrideToImplementCustomLogic_CallToSuperRecommended.
  120. """
  121. return getattr(obj, "__is_overriden__", True)
# Backward compatibility: keep `Deprecated` (imported from
# `ray.rllib.utils.deprecation` at the top of this module) reachable as a
# name of this module. The self-assignment also counts as a use of that
# import, which presumably keeps linters from flagging it as unused.
Deprecated = Deprecated