# test_framework_agnostic_components.py

  1. from abc import ABCMeta, abstractmethod
  2. from gym.spaces import Discrete
  3. import numpy as np
  4. from pathlib import Path
  5. import unittest
  6. from ray.rllib.utils.exploration.exploration import Exploration
  7. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  8. from ray.rllib.utils.from_config import from_config
  9. from ray.rllib.utils.test_utils import check, framework_iterator
# Framework modules, obtained through RLlib's import helpers rather than a
# plain `import`. `tf` is used by `DummyComponent._add_tf()` and `tf1` for
# the variable scope in `test_dummy_components`.
# NOTE(review): presumably these helpers tolerate a missing framework
# install (e.g. returning None placeholders) -- confirm in
# ray.rllib.utils.framework.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
  12. class DummyComponent:
  13. """A simple class that can be used for testing framework-agnostic logic.
  14. Implements a simple `add()` method for adding a value to
  15. `self.prop_b`.
  16. """
  17. def __init__(self,
  18. prop_a,
  19. prop_b=0.5,
  20. prop_c=None,
  21. framework="tf",
  22. **kwargs):
  23. self.framework = framework
  24. self.prop_a = prop_a
  25. self.prop_b = prop_b
  26. self.prop_c = prop_c or "default"
  27. self.prop_d = kwargs.pop("prop_d", 4)
  28. self.kwargs = kwargs
  29. def add(self, value):
  30. if self.framework == "tf":
  31. return self._add_tf(value)
  32. return self.prop_b + value
  33. def _add_tf(self, value):
  34. return tf.add(self.prop_b, value)
class NonAbstractChildOfDummyComponent(DummyComponent):
    """Concrete, behavior-less subclass of `DummyComponent`.

    `test_dummy_components` refers to it only by its bare class name
    ("NonAbstractChildOfDummyComponent"), passing `DummyComponent` as the
    default class, to check that `from_config()` recognizes the default
    module path.
    """
    pass
  37. class AbstractDummyComponent(DummyComponent, metaclass=ABCMeta):
  38. """Used for testing `from_config()`.
  39. """
  40. @abstractmethod
  41. def some_abstract_method(self):
  42. raise NotImplementedError
  43. class TestFrameWorkAgnosticComponents(unittest.TestCase):
  44. """
  45. Tests the Component base class to implement framework-agnostic functional
  46. units.
  47. """
  48. def test_dummy_components(self):
  49. # Bazel makes it hard to find files specified in `args`
  50. # (and `data`).
  51. # Use the true absolute path.
  52. script_dir = Path(__file__).parent
  53. abs_path = script_dir.absolute()
  54. for fw, sess in framework_iterator(session=True):
  55. fw_ = fw if fw != "tfe" else "tf"
  56. # Try to create from an abstract class w/o default constructor.
  57. # Expect None.
  58. test = from_config({
  59. "type": AbstractDummyComponent,
  60. "framework": fw_
  61. })
  62. check(test, None)
  63. # Create a Component via python API (config dict).
  64. component = from_config(
  65. dict(
  66. type=DummyComponent,
  67. prop_a=1.0,
  68. prop_d="non_default",
  69. framework=fw_))
  70. check(component.prop_d, "non_default")
  71. # Create a tf Component from json file.
  72. config_file = str(abs_path.joinpath("dummy_config.json"))
  73. component = from_config(config_file, framework=fw_)
  74. check(component.prop_c, "default")
  75. check(component.prop_d, 4) # default
  76. value = component.add(3.3)
  77. if sess:
  78. value = sess.run(value)
  79. check(value, 5.3) # prop_b == 2.0
  80. # Create a torch Component from yaml file.
  81. config_file = str(abs_path.joinpath("dummy_config.yml"))
  82. component = from_config(config_file, framework=fw_)
  83. check(component.prop_a, "something else")
  84. check(component.prop_d, 3)
  85. value = component.add(1.2)
  86. if sess:
  87. value = sess.run(value)
  88. check(value, np.array([2.2])) # prop_b == 1.0
  89. # Create tf Component from json-string (e.g. on command line).
  90. component = from_config(
  91. '{"type": "ray.rllib.utils.tests.'
  92. 'test_framework_agnostic_components.DummyComponent", '
  93. '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default", '
  94. '"framework": "' + fw_ + '"}')
  95. check(component.prop_a, "A")
  96. check(component.prop_d, 4) # default
  97. value = component.add(-1.1)
  98. if sess:
  99. value = sess.run(value)
  100. check(value, -2.1) # prop_b == -1.0
  101. # Test recognizing default module path.
  102. component = from_config(
  103. DummyComponent, '{"type": "NonAbstractChildOfDummyComponent", '
  104. '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default",'
  105. '"framework": "' + fw_ + '"}')
  106. check(component.prop_a, "A")
  107. check(component.prop_d, 4) # default
  108. value = component.add(-1.1)
  109. if sess:
  110. value = sess.run(value)
  111. check(value, -2.1) # prop_b == -1.0
  112. # Test recognizing default package path.
  113. scope = None
  114. if sess:
  115. scope = tf1.variable_scope("exploration_object")
  116. scope.__enter__()
  117. component = from_config(
  118. Exploration, {
  119. "type": "EpsilonGreedy",
  120. "action_space": Discrete(2),
  121. "framework": fw_,
  122. "num_workers": 0,
  123. "worker_index": 0,
  124. "policy_config": {},
  125. "model": None
  126. })
  127. if scope:
  128. scope.__exit__(None, None, None)
  129. check(component.epsilon_schedule.outside_value, 0.05) # default
  130. # Create torch Component from yaml-string.
  131. component = from_config(
  132. "type: ray.rllib.utils.tests."
  133. "test_framework_agnostic_components.DummyComponent\n"
  134. "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: "
  135. "{}".format(fw_))
  136. check(component.prop_a, "B")
  137. check(component.prop_d, 4) # default
  138. value = component.add(-5.1)
  139. if sess:
  140. value = sess.run(value)
  141. check(value, np.array([-6.6])) # prop_b == -1.5
  142. def test_unregistered_envs(self):
  143. """Tests, whether an Env can be specified simply by its absolute class.
  144. """
  145. env_cls = "ray.rllib.examples.env.stateless_cartpole.StatelessCartPole"
  146. env = from_config(env_cls, {"config": 42.0})
  147. state = env.reset()
  148. self.assertTrue(state.shape == (2, ))
  149. if __name__ == "__main__":
  150. import pytest
  151. import sys
  152. sys.exit(pytest.main(["-v", __file__]))