# test_catalog.py (8.0 KB)
  1. from functools import partial
  2. import gym
  3. from gym.spaces import Box, Dict, Discrete
  4. import numpy as np
  5. import unittest
  6. import ray
  7. from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS
  8. from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
  9. from ray.rllib.models.tf.tf_action_dist import MultiActionDistribution, \
  10. TFActionDistribution
  11. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  12. from ray.rllib.utils.annotations import override
  13. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  14. from ray.rllib.utils.test_utils import framework_iterator
  15. tf1, tf, tfv = try_import_tf()
  16. torch, _ = try_import_torch()
  17. class CustomPreprocessor(Preprocessor):
  18. def _init_shape(self, obs_space, options):
  19. return [1]
  20. class CustomPreprocessor2(Preprocessor):
  21. def _init_shape(self, obs_space, options):
  22. return [1]
  23. class CustomModel(TFModelV2):
  24. def _build_layers(self, *args):
  25. return tf.constant([[0] * 5]), None
  26. class CustomActionDistribution(TFActionDistribution):
  27. def __init__(self, inputs, model):
  28. # Store our output shape.
  29. custom_model_config = model.model_config["custom_model_config"]
  30. if "output_dim" in custom_model_config:
  31. self.output_shape = tf.concat(
  32. [tf.shape(inputs)[:1], custom_model_config["output_dim"]],
  33. axis=0)
  34. else:
  35. self.output_shape = tf.shape(inputs)
  36. super().__init__(inputs, model)
  37. @staticmethod
  38. def required_model_output_shape(action_space, model_config=None):
  39. custom_model_config = model_config["custom_model_config"] or {}
  40. if custom_model_config is not None and \
  41. custom_model_config.get("output_dim"):
  42. return custom_model_config.get("output_dim")
  43. return action_space.shape
  44. @override(TFActionDistribution)
  45. def _build_sample_op(self):
  46. return tf.random.uniform(self.output_shape)
  47. @override(ActionDistribution)
  48. def logp(self, x):
  49. return tf.zeros(self.output_shape)
  50. class CustomMultiActionDistribution(MultiActionDistribution):
  51. @override(MultiActionDistribution)
  52. def entropy(self):
  53. raise NotImplementedError
  54. class TestModelCatalog(unittest.TestCase):
  55. def tearDown(self):
  56. ray.shutdown()
  57. def test_custom_preprocessor(self):
  58. ray.init(object_store_memory=1000 * 1024 * 1024)
  59. ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor)
  60. ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2)
  61. env = gym.make("CartPole-v0")
  62. p1 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "foo"})
  63. self.assertEqual(str(type(p1)), str(CustomPreprocessor))
  64. p2 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "bar"})
  65. self.assertEqual(str(type(p2)), str(CustomPreprocessor2))
  66. p3 = ModelCatalog.get_preprocessor(env)
  67. self.assertEqual(type(p3), NoPreprocessor)
  68. def test_default_models(self):
  69. ray.init(object_store_memory=1000 * 1024 * 1024)
  70. for fw in framework_iterator(frameworks=("jax", "tf", "tf2", "torch")):
  71. obs_space = Box(0, 1, shape=(3, ), dtype=np.float32)
  72. p1 = ModelCatalog.get_model_v2(
  73. obs_space=obs_space,
  74. action_space=Discrete(5),
  75. num_outputs=5,
  76. model_config={},
  77. framework=fw,
  78. )
  79. self.assertTrue("FullyConnectedNetwork" in type(p1).__name__)
  80. # Do a test forward pass.
  81. obs = np.array([obs_space.sample()])
  82. if fw == "torch":
  83. obs = torch.from_numpy(obs)
  84. out, state_outs = p1({"obs": obs})
  85. self.assertTrue(out.shape == (1, 5))
  86. self.assertTrue(state_outs == [])
  87. # No Conv2Ds for JAX yet.
  88. if fw != "jax":
  89. p2 = ModelCatalog.get_model_v2(
  90. obs_space=Box(0, 1, shape=(84, 84, 3), dtype=np.float32),
  91. action_space=Discrete(5),
  92. num_outputs=5,
  93. model_config={},
  94. framework=fw,
  95. )
  96. self.assertTrue("VisionNetwork" in type(p2).__name__)
  97. def test_custom_model(self):
  98. ray.init(object_store_memory=1000 * 1024 * 1024)
  99. ModelCatalog.register_custom_model("foo", CustomModel)
  100. p1 = ModelCatalog.get_model_v2(
  101. obs_space=Box(0, 1, shape=(3, ), dtype=np.float32),
  102. action_space=Discrete(5),
  103. num_outputs=5,
  104. model_config={"custom_model": "foo"})
  105. self.assertEqual(str(type(p1)), str(CustomModel))
  106. def test_custom_action_distribution(self):
  107. class Model():
  108. pass
  109. ray.init(
  110. object_store_memory=1000 * 1024 * 1024,
  111. ignore_reinit_error=True) # otherwise fails sometimes locally
  112. # registration
  113. ModelCatalog.register_custom_action_dist("test",
  114. CustomActionDistribution)
  115. action_space = Box(0, 1, shape=(5, 3), dtype=np.float32)
  116. # test retrieving it
  117. model_config = MODEL_DEFAULTS.copy()
  118. model_config["custom_action_dist"] = "test"
  119. dist_cls, param_shape = ModelCatalog.get_action_dist(
  120. action_space, model_config)
  121. self.assertEqual(str(dist_cls), str(CustomActionDistribution))
  122. self.assertEqual(param_shape, action_space.shape)
  123. # test the class works as a distribution
  124. dist_input = tf1.placeholder(tf.float32, (None, ) + param_shape)
  125. model = Model()
  126. model.model_config = model_config
  127. dist = dist_cls(dist_input, model=model)
  128. self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
  129. self.assertIsInstance(dist.sample(), tf.Tensor)
  130. with self.assertRaises(NotImplementedError):
  131. dist.entropy()
  132. # test passing the options to it
  133. model_config["custom_model_config"].update({"output_dim": (3, )})
  134. dist_cls, param_shape = ModelCatalog.get_action_dist(
  135. action_space, model_config)
  136. self.assertEqual(param_shape, (3, ))
  137. dist_input = tf1.placeholder(tf.float32, (None, ) + param_shape)
  138. model.model_config = model_config
  139. dist = dist_cls(dist_input, model=model)
  140. self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
  141. self.assertIsInstance(dist.sample(), tf.Tensor)
  142. with self.assertRaises(NotImplementedError):
  143. dist.entropy()
  144. def test_custom_multi_action_distribution(self):
  145. class Model():
  146. pass
  147. ray.init(
  148. object_store_memory=1000 * 1024 * 1024,
  149. ignore_reinit_error=True) # otherwise fails sometimes locally
  150. # registration
  151. ModelCatalog.register_custom_action_dist(
  152. "test", CustomMultiActionDistribution)
  153. s1 = Discrete(5)
  154. s2 = Box(0, 1, shape=(3, ), dtype=np.float32)
  155. spaces = dict(action_1=s1, action_2=s2)
  156. action_space = Dict(spaces)
  157. # test retrieving it
  158. model_config = MODEL_DEFAULTS.copy()
  159. model_config["custom_action_dist"] = "test"
  160. dist_cls, param_shape = ModelCatalog.get_action_dist(
  161. action_space, model_config)
  162. self.assertIsInstance(dist_cls, partial)
  163. self.assertEqual(param_shape, s1.n + 2 * s2.shape[0])
  164. # test the class works as a distribution
  165. dist_input = tf1.placeholder(tf.float32, (None, param_shape))
  166. model = Model()
  167. model.model_config = model_config
  168. dist = dist_cls(dist_input, model=model)
  169. self.assertIsInstance(dist.sample(), dict)
  170. self.assertIn("action_1", dist.sample())
  171. self.assertIn("action_2", dist.sample())
  172. self.assertEqual(dist.sample()["action_1"].dtype, tf.int64)
  173. self.assertEqual(dist.sample()["action_2"].shape[1:], s2.shape)
  174. with self.assertRaises(NotImplementedError):
  175. dist.entropy()
  176. if __name__ == "__main__":
  177. import pytest
  178. import sys
  179. sys.exit(pytest.main(["-v", __file__]))