# Source file: test_supported_spaces.py (7.9 KB)
  1. from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
  2. import numpy as np
  3. import unittest
  4. import ray
  5. from ray.rllib.agents.registry import get_trainer_class
  6. from ray.rllib.examples.env.random_env import RandomEnv
  7. from ray.rllib.models.tf.complex_input_net import ComplexInputNetwork as \
  8. ComplexNet
  9. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNet
  10. from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNet
  11. from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork as \
  12. TorchComplexNet
  13. from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet
  14. from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNet
  15. from ray.rllib.utils.error import UnsupportedSpaceException
  16. from ray.rllib.utils.test_utils import framework_iterator
# Action spaces exercised by `check_support`, keyed by a short descriptive
# name. Insertion order matters: `check_support` zips this dict's keys with
# those of OBSERVATION_SPACES_TO_TEST, so the i-th action space is paired with
# the i-th observation space.
ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector1d": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    # NOTE(review): despite its name, this entry has the same 1-D shape as
    # "vector1d" - confirm whether a genuinely 2-D action space was intended.
    "vector2d": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    # Integer-valued Box with a 2-D shape.
    "int_actions": Box(0, 3, (2, 3), dtype=np.int32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    # Mixed discrete/continuous tuple.
    "tuple": Tuple(
        [Discrete(2),
         Discrete(3),
         Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    # Dict space that itself nests a Dict holding a Tuple.
    "dict": Dict({
        "action_choice": Discrete(3),
        "parameters": Box(-1.0, 1.0, (1, ), dtype=np.float32),
        "yet_another_nested_dict": Dict({
            "a": Tuple([Discrete(2), Discrete(3)])
        })
    }),
}
# Observation spaces exercised by `check_support`, keyed by a short
# descriptive name. `check_support` uses some of these names ("image",
# "vector1d", "vector2d") to decide which default model class to expect, and
# pairs entries with ACTION_SPACES_TO_TEST by insertion order - so renaming or
# reordering entries changes what is tested.
OBSERVATION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector1d": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2d": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    # Single-channel 84x84 image.
    "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    # Three-channel 240x320 image.
    "vizdoomgym": Box(-1.0, 1.0, (240, 320, 3), dtype=np.float32),
    "tuple": Tuple([Discrete(10),
                    Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    "dict": Dict({
        "task": Discrete(10),
        "position": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    }),
}
  48. def check_support(alg, config, train=True, check_bounds=False, tfe=False):
  49. config["log_level"] = "ERROR"
  50. config["train_batch_size"] = 10
  51. config["rollout_fragment_length"] = 10
  52. def _do_check(alg, config, a_name, o_name):
  53. fw = config["framework"]
  54. action_space = ACTION_SPACES_TO_TEST[a_name]
  55. obs_space = OBSERVATION_SPACES_TO_TEST[o_name]
  56. print("=== Testing {} (fw={}) A={} S={} ===".format(
  57. alg, fw, action_space, obs_space))
  58. config.update(
  59. dict(
  60. env_config=dict(
  61. action_space=action_space,
  62. observation_space=obs_space,
  63. reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
  64. p_done=1.0,
  65. check_action_bounds=check_bounds)))
  66. stat = "ok"
  67. try:
  68. a = get_trainer_class(alg)(config=config, env=RandomEnv)
  69. except ray.exceptions.RayActorError as e:
  70. if isinstance(e.args[2], UnsupportedSpaceException):
  71. stat = "unsupported"
  72. else:
  73. raise
  74. except UnsupportedSpaceException:
  75. stat = "unsupported"
  76. else:
  77. if alg not in ["DDPG", "ES", "ARS", "SAC"]:
  78. # 2D (image) input: Expect VisionNet.
  79. if o_name in ["atari", "image"]:
  80. if fw == "torch":
  81. assert isinstance(a.get_policy().model, TorchVisionNet)
  82. else:
  83. assert isinstance(a.get_policy().model, VisionNet)
  84. # 1D input: Expect FCNet.
  85. elif o_name == "vector1d":
  86. if fw == "torch":
  87. assert isinstance(a.get_policy().model, TorchFCNet)
  88. else:
  89. assert isinstance(a.get_policy().model, FCNet)
  90. # Could be either one: ComplexNet (if disabled Preprocessor)
  91. # or FCNet (w/ Preprocessor).
  92. elif o_name == "vector2d":
  93. if fw == "torch":
  94. assert isinstance(a.get_policy().model,
  95. (TorchComplexNet, TorchFCNet))
  96. else:
  97. assert isinstance(a.get_policy().model,
  98. (ComplexNet, FCNet))
  99. if train:
  100. a.train()
  101. a.stop()
  102. print(stat)
  103. frameworks = ("tf", "torch")
  104. if tfe:
  105. frameworks += ("tf2", "tfe")
  106. for _ in framework_iterator(config, frameworks=frameworks):
  107. # Zip through action- and obs-spaces.
  108. for a_name, o_name in zip(ACTION_SPACES_TO_TEST.keys(),
  109. OBSERVATION_SPACES_TO_TEST.keys()):
  110. _do_check(alg, config, a_name, o_name)
  111. # Do the remaining obs spaces.
  112. assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST)
  113. fixed_action_key = next(iter(ACTION_SPACES_TO_TEST.keys()))
  114. for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()):
  115. if i < len(ACTION_SPACES_TO_TEST):
  116. continue
  117. _do_check(alg, config, fixed_action_key, o_name)
  118. class TestSupportedSpacesPG(unittest.TestCase):
  119. @classmethod
  120. def setUpClass(cls) -> None:
  121. ray.init()
  122. @classmethod
  123. def tearDownClass(cls) -> None:
  124. ray.shutdown()
  125. def test_a3c(self):
  126. config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
  127. check_support("A3C", config, check_bounds=True)
  128. def test_appo(self):
  129. check_support("APPO", {"num_gpus": 0, "vtrace": False}, train=False)
  130. check_support("APPO", {"num_gpus": 0, "vtrace": True})
  131. def test_impala(self):
  132. check_support("IMPALA", {"num_gpus": 0})
  133. def test_ppo(self):
  134. config = {
  135. "num_workers": 0,
  136. "train_batch_size": 100,
  137. "rollout_fragment_length": 10,
  138. "num_sgd_iter": 1,
  139. "sgd_minibatch_size": 10,
  140. }
  141. check_support("PPO", config, check_bounds=True, tfe=True)
  142. def test_pg(self):
  143. config = {"num_workers": 1, "optimizer": {}}
  144. check_support("PG", config, train=False, check_bounds=True, tfe=True)
  145. class TestSupportedSpacesOffPolicy(unittest.TestCase):
  146. @classmethod
  147. def setUpClass(cls) -> None:
  148. ray.init(num_cpus=4)
  149. @classmethod
  150. def tearDownClass(cls) -> None:
  151. ray.shutdown()
  152. def test_ddpg(self):
  153. check_support(
  154. "DDPG", {
  155. "exploration_config": {
  156. "ou_base_scale": 100.0
  157. },
  158. "timesteps_per_iteration": 1,
  159. "buffer_size": 1000,
  160. "use_state_preprocessor": True,
  161. },
  162. check_bounds=True)
  163. def test_dqn(self):
  164. config = {"timesteps_per_iteration": 1, "buffer_size": 1000}
  165. check_support("DQN", config, tfe=True)
  166. def test_sac(self):
  167. check_support("SAC", {"buffer_size": 1000}, check_bounds=True)
  168. class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
  169. @classmethod
  170. def setUpClass(cls) -> None:
  171. ray.init(num_cpus=4)
  172. @classmethod
  173. def tearDownClass(cls) -> None:
  174. ray.shutdown()
  175. def test_ars(self):
  176. check_support(
  177. "ARS", {
  178. "num_workers": 1,
  179. "noise_size": 1500000,
  180. "num_rollouts": 1,
  181. "rollouts_used": 1
  182. })
  183. def test_es(self):
  184. check_support(
  185. "ES", {
  186. "num_workers": 1,
  187. "noise_size": 1500000,
  188. "episodes_per_batch": 1,
  189. "train_batch_size": 1
  190. })
  191. if __name__ == "__main__":
  192. import pytest
  193. import sys
  194. # One can specify the specific TestCase class to run.
  195. # None for all unittest.TestCase classes in this file.
  196. class_ = sys.argv[1] if len(sys.argv) > 1 else None
  197. sys.exit(
  198. pytest.main(
  199. ["-v", __file__ + ("" if class_ is None else "::" + class_)]))