test_models.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from gym.spaces import Box
  2. import numpy as np
  3. import unittest
  4. import ray
  5. import ray.rllib.agents.ppo as ppo
  6. from ray.rllib.examples.models.modelv3 import RNNModel
  7. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  8. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
  9. from ray.rllib.utils.framework import try_import_tf
# Import TensorFlow through RLlib's framework helper, which returns three
# values. Presumably these are the tf1 (compat.v1) module, the tf module and
# the TF version — confirm against ray.rllib.utils.framework.try_import_tf.
# Only ``tf`` is referenced by the code below; ``tf1``/``tfv`` are unused here.
tf1, tf, tfv = try_import_tf()
  11. class TestTFModel(TFModelV2):
  12. def __init__(self, obs_space, action_space, num_outputs, model_config,
  13. name):
  14. super().__init__(obs_space, action_space, num_outputs, model_config,
  15. name)
  16. input_ = tf.keras.layers.Input(shape=(3, ))
  17. output = tf.keras.layers.Dense(2)(input_)
  18. # A keras model inside.
  19. self.keras_model = tf.keras.models.Model([input_], [output])
  20. # A RLlib FullyConnectedNetwork (tf) inside (which is also a keras
  21. # Model).
  22. self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {},
  23. "fc1")
  24. def forward(self, input_dict, state, seq_lens):
  25. obs = input_dict["obs_flat"]
  26. out1 = self.keras_model(obs)
  27. out2, _ = self.fc_net({"obs": obs})
  28. return tf.concat([out1, out2], axis=-1), []
  29. class TestModels(unittest.TestCase):
  30. """Tests ModelV2 classes and their modularization capabilities."""
  31. @classmethod
  32. def setUpClass(cls) -> None:
  33. ray.init()
  34. @classmethod
  35. def tearDownClass(cls) -> None:
  36. ray.shutdown()
  37. def test_tf_modelv2(self):
  38. obs_space = Box(-1.0, 1.0, (3, ))
  39. action_space = Box(-1.0, 1.0, (2, ))
  40. my_tf_model = TestTFModel(obs_space, action_space, 5, {},
  41. "my_tf_model")
  42. # Call the model.
  43. out, states = my_tf_model({"obs": np.array([obs_space.sample()])})
  44. self.assertTrue(out.shape == (1, 5))
  45. self.assertTrue(out.dtype == tf.float32)
  46. self.assertTrue(states == [])
  47. vars = my_tf_model.variables(as_dict=True)
  48. self.assertTrue(len(vars) == 6)
  49. self.assertTrue("keras_model.dense.kernel:0" in vars)
  50. self.assertTrue("keras_model.dense.bias:0" in vars)
  51. self.assertTrue("fc_net.base_model.fc_out.kernel:0" in vars)
  52. self.assertTrue("fc_net.base_model.fc_out.bias:0" in vars)
  53. self.assertTrue("fc_net.base_model.value_out.kernel:0" in vars)
  54. self.assertTrue("fc_net.base_model.value_out.bias:0" in vars)
  55. def test_modelv3(self):
  56. config = {
  57. "env": "CartPole-v0",
  58. "model": {
  59. "custom_model": RNNModel,
  60. "custom_model_config": {
  61. "hiddens_size": 64,
  62. "cell_size": 128,
  63. },
  64. },
  65. "num_workers": 0,
  66. }
  67. trainer = ppo.PPOTrainer(config=config)
  68. for _ in range(2):
  69. results = trainer.train()
  70. print(results)
  71. if __name__ == "__main__":
  72. import pytest
  73. import sys
  74. sys.exit(pytest.main(["-v", __file__]))