__init__.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import contextlib
  2. from functools import partial
  3. from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
  4. from ray.rllib.utils.deprecation import deprecation_warning
  5. from ray.rllib.utils.filter import Filter
  6. from ray.rllib.utils.filter_manager import FilterManager
  7. from ray.rllib.utils.framework import (
  8. try_import_jax,
  9. try_import_tf,
  10. try_import_tfp,
  11. try_import_torch,
  12. )
  13. from ray.rllib.utils.numpy import (
  14. sigmoid,
  15. softmax,
  16. relu,
  17. one_hot,
  18. fc,
  19. lstm,
  20. SMALL_NUMBER,
  21. LARGE_INTEGER,
  22. MIN_LOG_NN_OUTPUT,
  23. MAX_LOG_NN_OUTPUT,
  24. )
  25. from ray.rllib.utils.pre_checks.env import check_env
  26. from ray.rllib.utils.schedules import (
  27. LinearSchedule,
  28. PiecewiseSchedule,
  29. PolynomialSchedule,
  30. ExponentialSchedule,
  31. ConstantSchedule,
  32. )
  33. from ray.rllib.utils.test_utils import (
  34. check,
  35. check_compute_single_action,
  36. check_train_results,
  37. framework_iterator,
  38. )
  39. from ray.tune.utils import merge_dicts, deep_update
  40. @DeveloperAPI
  41. def add_mixins(base, mixins, reversed=False):
  42. """Returns a new class with mixins applied in priority order."""
  43. mixins = list(mixins or [])
  44. while mixins:
  45. if reversed:
  46. class new_base(base, mixins.pop()):
  47. pass
  48. else:
  49. class new_base(mixins.pop(), base):
  50. pass
  51. base = new_base
  52. return base
  53. @DeveloperAPI
  54. def force_list(elements=None, to_tuple=False):
  55. """
  56. Makes sure `elements` is returned as a list, whether `elements` is a single
  57. item, already a list, or a tuple.
  58. Args:
  59. elements (Optional[any]): The inputs as single item, list, or tuple to
  60. be converted into a list/tuple. If None, returns empty list/tuple.
  61. to_tuple: Whether to use tuple (instead of list).
  62. Returns:
  63. Union[list,tuple]: All given elements in a list/tuple depending on
  64. `to_tuple`'s value. If elements is None,
  65. returns an empty list/tuple.
  66. """
  67. ctor = list
  68. if to_tuple is True:
  69. ctor = tuple
  70. return (
  71. ctor()
  72. if elements is None
  73. else ctor(elements)
  74. if type(elements) in [list, set, tuple]
  75. else ctor([elements])
  76. )
  77. @DeveloperAPI
  78. class NullContextManager(contextlib.AbstractContextManager):
  79. """No-op context manager"""
  80. def __init__(self):
  81. pass
  82. def __enter__(self):
  83. pass
  84. def __exit__(self, *args):
  85. pass
  86. force_tuple = partial(force_list, to_tuple=True)
  87. __all__ = [
  88. "add_mixins",
  89. "check",
  90. "check_env",
  91. "check_compute_single_action",
  92. "check_train_results",
  93. "deep_update",
  94. "deprecation_warning",
  95. "fc",
  96. "force_list",
  97. "force_tuple",
  98. "framework_iterator",
  99. "lstm",
  100. "merge_dicts",
  101. "one_hot",
  102. "override",
  103. "relu",
  104. "sigmoid",
  105. "softmax",
  106. "try_import_jax",
  107. "try_import_tf",
  108. "try_import_tfp",
  109. "try_import_torch",
  110. "ConstantSchedule",
  111. "DeveloperAPI",
  112. "ExponentialSchedule",
  113. "Filter",
  114. "FilterManager",
  115. "LARGE_INTEGER",
  116. "LinearSchedule",
  117. "MAX_LOG_NN_OUTPUT",
  118. "MIN_LOG_NN_OUTPUT",
  119. "PiecewiseSchedule",
  120. "PolynomialSchedule",
  121. "PublicAPI",
  122. "SMALL_NUMBER",
  123. ]