# test_models.py (3.0 KB)
  1. from gymnasium.spaces import Box
  2. import numpy as np
  3. import unittest
  4. import ray
  5. import ray.rllib.algorithms.ppo as ppo
  6. from ray.rllib.examples.models.modelv3 import RNNModel
  7. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  8. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
  9. from ray.rllib.utils.framework import try_import_tf
  10. tf1, tf, tfv = try_import_tf()
  11. class TestTFModel(TFModelV2):
  12. def __init__(self, obs_space, action_space, num_outputs, model_config, name):
  13. super().__init__(obs_space, action_space, num_outputs, model_config, name)
  14. input_ = tf.keras.layers.Input(shape=(3,))
  15. output = tf.keras.layers.Dense(2)(input_)
  16. # A keras model inside.
  17. self.keras_model = tf.keras.models.Model([input_], [output])
  18. # A RLlib FullyConnectedNetwork (tf) inside (which is also a keras
  19. # Model).
  20. self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {}, "fc1")
  21. def forward(self, input_dict, state, seq_lens):
  22. obs = input_dict["obs_flat"]
  23. out1 = self.keras_model(obs)
  24. out2, _ = self.fc_net({"obs": obs})
  25. return tf.concat([out1, out2], axis=-1), []
  26. class TestModels(unittest.TestCase):
  27. """Tests ModelV2 classes and their modularization capabilities."""
  28. @classmethod
  29. def setUpClass(cls) -> None:
  30. ray.init()
  31. @classmethod
  32. def tearDownClass(cls) -> None:
  33. ray.shutdown()
  34. def test_tf_modelv2(self):
  35. obs_space = Box(-1.0, 1.0, (3,))
  36. action_space = Box(-1.0, 1.0, (2,))
  37. my_tf_model = TestTFModel(obs_space, action_space, 5, {}, "my_tf_model")
  38. # Call the model.
  39. out, states = my_tf_model({"obs": np.array([obs_space.sample()])})
  40. self.assertTrue(out.shape == (1, 5))
  41. self.assertTrue(out.dtype == tf.float32)
  42. self.assertTrue(states == [])
  43. vars = my_tf_model.variables(as_dict=True)
  44. self.assertTrue(len(vars) == 6)
  45. self.assertTrue("keras_model.dense.kernel:0" in vars)
  46. self.assertTrue("keras_model.dense.bias:0" in vars)
  47. self.assertTrue("fc_net.base_model.fc_out.kernel:0" in vars)
  48. self.assertTrue("fc_net.base_model.fc_out.bias:0" in vars)
  49. self.assertTrue("fc_net.base_model.value_out.kernel:0" in vars)
  50. self.assertTrue("fc_net.base_model.value_out.bias:0" in vars)
  51. def test_modelv3(self):
  52. config = (
  53. ppo.PPOConfig()
  54. .environment("CartPole-v1")
  55. .framework("tf")
  56. .rollouts(num_rollout_workers=0)
  57. .training(
  58. model={
  59. "custom_model": RNNModel,
  60. "custom_model_config": {
  61. "hiddens_size": 64,
  62. "cell_size": 128,
  63. },
  64. }
  65. )
  66. )
  67. algo = config.build()
  68. for _ in range(2):
  69. results = algo.train()
  70. print(results)
  71. algo.stop()
  72. if __name__ == "__main__":
  73. import pytest
  74. import sys
  75. sys.exit(pytest.main(["-v", __file__]))