__init__.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import contextlib
  2. from functools import partial
  3. from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
  4. from ray.rllib.utils.framework import try_import_tf, try_import_tfp, \
  5. try_import_torch
  6. from ray.rllib.utils.deprecation import deprecation_warning
  7. from ray.rllib.utils.filter_manager import FilterManager
  8. from ray.rllib.utils.filter import Filter
  9. from ray.rllib.utils.numpy import sigmoid, softmax, relu, one_hot, fc, lstm, \
  10. SMALL_NUMBER, LARGE_INTEGER, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT
  11. from ray.rllib.utils.pre_checks.env import check_env
  12. from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
  13. PolynomialSchedule, ExponentialSchedule, ConstantSchedule
  14. from ray.rllib.utils.test_utils import check, check_compute_single_action, \
  15. check_train_results, framework_iterator
  16. from ray.tune.utils import merge_dicts, deep_update
  17. def add_mixins(base, mixins, reversed=False):
  18. """Returns a new class with mixins applied in priority order."""
  19. mixins = list(mixins or [])
  20. while mixins:
  21. if reversed:
  22. class new_base(base, mixins.pop()):
  23. pass
  24. else:
  25. class new_base(mixins.pop(), base):
  26. pass
  27. base = new_base
  28. return base
  29. def force_list(elements=None, to_tuple=False):
  30. """
  31. Makes sure `elements` is returned as a list, whether `elements` is a single
  32. item, already a list, or a tuple.
  33. Args:
  34. elements (Optional[any]): The inputs as single item, list, or tuple to
  35. be converted into a list/tuple. If None, returns empty list/tuple.
  36. to_tuple (bool): Whether to use tuple (instead of list).
  37. Returns:
  38. Union[list,tuple]: All given elements in a list/tuple depending on
  39. `to_tuple`'s value. If elements is None,
  40. returns an empty list/tuple.
  41. """
  42. ctor = list
  43. if to_tuple is True:
  44. ctor = tuple
  45. return ctor() if elements is None else ctor(elements) \
  46. if type(elements) in [list, tuple] else ctor([elements])
  47. class NullContextManager(contextlib.AbstractContextManager):
  48. """No-op context manager"""
  49. def __init__(self):
  50. pass
  51. def __enter__(self):
  52. pass
  53. def __exit__(self, *args):
  54. pass
  55. force_tuple = partial(force_list, to_tuple=True)
  56. __all__ = [
  57. "add_mixins",
  58. "check",
  59. "check_env",
  60. "check_compute_single_action",
  61. "check_train_results",
  62. "deep_update",
  63. "deprecation_warning",
  64. "fc",
  65. "force_list",
  66. "force_tuple",
  67. "framework_iterator",
  68. "lstm",
  69. "merge_dicts",
  70. "one_hot",
  71. "override",
  72. "relu",
  73. "sigmoid",
  74. "softmax",
  75. "try_import_tf",
  76. "try_import_tfp",
  77. "try_import_torch",
  78. "ConstantSchedule",
  79. "DeveloperAPI",
  80. "ExponentialSchedule",
  81. "Filter",
  82. "FilterManager",
  83. "LARGE_INTEGER",
  84. "LinearSchedule",
  85. "MAX_LOG_NN_OUTPUT",
  86. "MIN_LOG_NN_OUTPUT",
  87. "PiecewiseSchedule",
  88. "PolynomialSchedule",
  89. "PublicAPI",
  90. "SMALL_NUMBER",
  91. ]