test_catalog.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from functools import partial
  2. from gymnasium.spaces import Box, Dict, Discrete, Tuple
  3. import numpy as np
  4. import unittest
  5. import ray
  6. from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS
  7. from ray.rllib.models.preprocessors import (
  8. Preprocessor,
  9. TupleFlatteningPreprocessor,
  10. )
  11. from ray.rllib.models.tf.tf_action_dist import (
  12. MultiActionDistribution,
  13. TFActionDistribution,
  14. )
  15. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  16. from ray.rllib.utils.annotations import override
  17. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  18. from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
  19. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  20. tf1, tf, tfv = try_import_tf()
  21. torch, _ = try_import_torch()
  22. class CustomPreprocessor(Preprocessor):
  23. def _init_shape(self, obs_space, options):
  24. return [1]
  25. class CustomPreprocessor2(Preprocessor):
  26. def _init_shape(self, obs_space, options):
  27. return [1]
  28. class CustomModel(TFModelV2):
  29. def _build_layers(self, *args):
  30. return tf.constant([[0] * 5]), None
  31. class CustomActionDistribution(TFActionDistribution):
  32. def __init__(self, inputs, model):
  33. # Store our output shape.
  34. custom_model_config = model.model_config["custom_model_config"]
  35. if "output_dim" in custom_model_config:
  36. self.output_shape = tf.concat(
  37. [tf.shape(inputs)[:1], custom_model_config["output_dim"]], axis=0
  38. )
  39. else:
  40. self.output_shape = tf.shape(inputs)
  41. super().__init__(inputs, model)
  42. @staticmethod
  43. def required_model_output_shape(action_space, model_config=None):
  44. custom_model_config = model_config["custom_model_config"] or {}
  45. if custom_model_config is not None and custom_model_config.get("output_dim"):
  46. return custom_model_config.get("output_dim")
  47. return action_space.shape
  48. @override(TFActionDistribution)
  49. def _build_sample_op(self):
  50. return tf.random.uniform(self.output_shape)
  51. @override(ActionDistribution)
  52. def logp(self, x):
  53. return tf.zeros(self.output_shape)
  54. class CustomMultiActionDistribution(MultiActionDistribution):
  55. @override(MultiActionDistribution)
  56. def entropy(self):
  57. raise NotImplementedError
  58. class TestModelCatalog(unittest.TestCase):
  59. def tearDown(self):
  60. ray.shutdown()
  61. def test_default_models(self):
  62. ray.init(object_store_memory=1000 * 1024 * 1024)
  63. # Build test cases
  64. flat_input_case = {
  65. "obs_space": Box(0, 1, shape=(3,), dtype=np.float32),
  66. "action_space": Box(0, 1, shape=(4,)),
  67. "num_outputs": 4,
  68. "expected_model": "FullyConnectedNetwork",
  69. }
  70. img_input_case = {
  71. "obs_space": Box(0, 1, shape=(84, 84, 3), dtype=np.float32),
  72. "action_space": Discrete(5),
  73. "num_outputs": 5,
  74. "expected_model": "VisionNetwork",
  75. }
  76. complex_obs_space = Tuple(
  77. [
  78. Box(0, 1, shape=(3,), dtype=np.float32),
  79. Box(0, 1, shape=(4,), dtype=np.float32),
  80. Discrete(3),
  81. ]
  82. )
  83. obs_prep = TupleFlatteningPreprocessor(complex_obs_space)
  84. flat_complex_input_case = {
  85. "obs_space": obs_prep.observation_space,
  86. "action_space": Box(0, 1, shape=(5,)),
  87. "num_outputs": 5,
  88. "expected_model": "FullyConnectedNetwork",
  89. }
  90. nested_complex_input_case = {
  91. "obs_space": Tuple(
  92. [
  93. Box(0, 1, shape=(3,), dtype=np.float32),
  94. Discrete(3),
  95. Tuple(
  96. [
  97. Box(0, 1, shape=(84, 84, 3), dtype=np.float32),
  98. Box(0, 1, shape=(84, 84, 3), dtype=np.float32),
  99. ]
  100. ),
  101. ]
  102. ),
  103. "action_space": Box(0, 1, shape=(7,)),
  104. "num_outputs": 7,
  105. "expected_model": "ComplexInputNetwork",
  106. }
  107. # Define which tests to run per framework
  108. test_suite = {
  109. "tf": [
  110. flat_input_case,
  111. img_input_case,
  112. flat_complex_input_case,
  113. nested_complex_input_case,
  114. ],
  115. "tf2": [
  116. flat_input_case,
  117. img_input_case,
  118. flat_complex_input_case,
  119. nested_complex_input_case,
  120. ],
  121. "torch": [
  122. flat_input_case,
  123. img_input_case,
  124. flat_complex_input_case,
  125. nested_complex_input_case,
  126. ],
  127. }
  128. for fw, test_cases in test_suite.items():
  129. for test in test_cases:
  130. model_config = {}
  131. if test["expected_model"] == "ComplexInputNetwork":
  132. model_config["fcnet_hiddens"] = [256, 256]
  133. m = ModelCatalog.get_model_v2(
  134. obs_space=test["obs_space"],
  135. action_space=test["action_space"],
  136. num_outputs=test["num_outputs"],
  137. model_config=model_config,
  138. framework=fw,
  139. )
  140. self.assertTrue(test["expected_model"] in type(m).__name__)
  141. # Do a test forward pass.
  142. batch_size = 16
  143. obs = get_dummy_batch_for_space(
  144. test["obs_space"],
  145. batch_size=batch_size,
  146. fill_value="random",
  147. )
  148. if fw == "torch":
  149. obs = convert_to_torch_tensor(obs)
  150. out, state_outs = m({"obs": obs})
  151. self.assertTrue(out.shape == (batch_size, test["num_outputs"]))
  152. self.assertTrue(state_outs == [])
  153. def test_custom_model(self):
  154. ray.init(object_store_memory=1000 * 1024 * 1024)
  155. ModelCatalog.register_custom_model("foo", CustomModel)
  156. p1 = ModelCatalog.get_model_v2(
  157. obs_space=Box(0, 1, shape=(3,), dtype=np.float32),
  158. action_space=Discrete(5),
  159. num_outputs=5,
  160. model_config={"custom_model": "foo"},
  161. )
  162. self.assertEqual(str(type(p1)), str(CustomModel))
  163. def test_custom_action_distribution(self):
  164. class Model:
  165. pass
  166. ray.init(
  167. object_store_memory=1000 * 1024 * 1024, ignore_reinit_error=True
  168. ) # otherwise fails sometimes locally
  169. # registration
  170. ModelCatalog.register_custom_action_dist("test", CustomActionDistribution)
  171. action_space = Box(0, 1, shape=(5, 3), dtype=np.float32)
  172. # test retrieving it
  173. model_config = MODEL_DEFAULTS.copy()
  174. model_config["custom_action_dist"] = "test"
  175. dist_cls, param_shape = ModelCatalog.get_action_dist(action_space, model_config)
  176. self.assertEqual(str(dist_cls), str(CustomActionDistribution))
  177. self.assertEqual(param_shape, action_space.shape)
  178. # test the class works as a distribution
  179. dist_input = tf1.placeholder(tf.float32, (None,) + param_shape)
  180. model = Model()
  181. model.model_config = model_config
  182. dist = dist_cls(dist_input, model=model)
  183. self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
  184. self.assertIsInstance(dist.sample(), tf.Tensor)
  185. with self.assertRaises(NotImplementedError):
  186. dist.entropy()
  187. # test passing the options to it
  188. model_config["custom_model_config"].update({"output_dim": (3,)})
  189. dist_cls, param_shape = ModelCatalog.get_action_dist(action_space, model_config)
  190. self.assertEqual(param_shape, (3,))
  191. dist_input = tf1.placeholder(tf.float32, (None,) + param_shape)
  192. model.model_config = model_config
  193. dist = dist_cls(dist_input, model=model)
  194. self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
  195. self.assertIsInstance(dist.sample(), tf.Tensor)
  196. with self.assertRaises(NotImplementedError):
  197. dist.entropy()
  198. def test_custom_multi_action_distribution(self):
  199. class Model:
  200. pass
  201. ray.init(
  202. object_store_memory=1000 * 1024 * 1024, ignore_reinit_error=True
  203. ) # otherwise fails sometimes locally
  204. # registration
  205. ModelCatalog.register_custom_action_dist("test", CustomMultiActionDistribution)
  206. s1 = Discrete(5)
  207. s2 = Box(0, 1, shape=(3,), dtype=np.float32)
  208. spaces = dict(action_1=s1, action_2=s2)
  209. action_space = Dict(spaces)
  210. # test retrieving it
  211. model_config = MODEL_DEFAULTS.copy()
  212. model_config["custom_action_dist"] = "test"
  213. dist_cls, param_shape = ModelCatalog.get_action_dist(action_space, model_config)
  214. self.assertIsInstance(dist_cls, partial)
  215. self.assertEqual(param_shape, s1.n + 2 * s2.shape[0])
  216. # test the class works as a distribution
  217. dist_input = tf1.placeholder(tf.float32, (None, param_shape))
  218. model = Model()
  219. model.model_config = model_config
  220. dist = dist_cls(dist_input, model=model)
  221. self.assertIsInstance(dist.sample(), dict)
  222. self.assertIn("action_1", dist.sample())
  223. self.assertIn("action_2", dist.sample())
  224. self.assertEqual(dist.sample()["action_1"].dtype, tf.int64)
  225. self.assertEqual(dist.sample()["action_2"].shape[1:], s2.shape)
  226. with self.assertRaises(NotImplementedError):
  227. dist.entropy()
  228. if __name__ == "__main__":
  229. import pytest
  230. import sys
  231. sys.exit(pytest.main(["-v", __file__]))