test_dnc.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import gymnasium as gym
  2. import unittest
  3. import ray
  4. from ray import air
  5. from ray import tune
  6. from ray.rllib.algorithms.a2c import A2CConfig
  7. from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
  8. from ray.rllib.models.catalog import ModelCatalog
  9. from ray.rllib.examples.models.neural_computer import DNCMemory
  10. from ray.rllib.utils.framework import try_import_torch
  11. torch, _ = try_import_torch()
  12. class TestDNC(unittest.TestCase):
  13. stop = {
  14. "episode_reward_mean": 100.0,
  15. "timesteps_total": 10000000,
  16. }
  17. @classmethod
  18. def setUpClass(cls) -> None:
  19. ray.init(num_cpus=4, ignore_reinit_error=True)
  20. @classmethod
  21. def tearDownClass(cls) -> None:
  22. ray.shutdown()
  23. def test_pack_unpack(self):
  24. d = DNCMemory(gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
  25. # Add batch dim
  26. packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
  27. [m.random_() for m in packed_state]
  28. original_packed = [m.clone() for m in packed_state]
  29. B, T = packed_state[0].shape[:2]
  30. unpacked = d.unpack_state(packed_state)
  31. packed = d.pack_state(*unpacked)
  32. self.assertTrue(len(packed) > 0)
  33. self.assertEqual(len(packed), len(original_packed))
  34. for m_idx in range(len(packed)):
  35. self.assertTrue(torch.all(packed[m_idx] == original_packed[m_idx]))
  36. def test_dnc_learning(self):
  37. ModelCatalog.register_custom_model("dnc", DNCMemory)
  38. config = (
  39. A2CConfig()
  40. .environment(StatelessCartPole)
  41. .framework("torch")
  42. .rollouts(num_envs_per_worker=5, num_rollout_workers=1)
  43. .training(
  44. gamma=0.99,
  45. lr=0.01,
  46. entropy_coeff=0.0005,
  47. vf_loss_coeff=1e-5,
  48. model={
  49. "custom_model": "dnc",
  50. "max_seq_len": 64,
  51. "custom_model_config": {
  52. "nr_cells": 10,
  53. "cell_size": 8,
  54. },
  55. },
  56. )
  57. .resources(num_cpus_per_worker=2.0)
  58. )
  59. tune.Tuner(
  60. "A2C",
  61. param_space=config,
  62. run_config=air.RunConfig(stop=self.stop, verbose=1),
  63. ).fit()
  64. if __name__ == "__main__":
  65. import pytest
  66. import sys
  67. sys.exit(pytest.main(["-v", __file__]))