# test_dnc.py
import unittest

import gym
import ray
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.neural_computer import DNCMemory
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()


  10. class TestDNC(unittest.TestCase):
  11. stop = {
  12. "episode_reward_mean": 100.0,
  13. "timesteps_total": 10000000,
  14. }
  15. @classmethod
  16. def setUpClass(cls) -> None:
  17. ray.init(num_cpus=4, ignore_reinit_error=True)
  18. @classmethod
  19. def tearDownClass(cls) -> None:
  20. ray.shutdown()
  21. def test_pack_unpack(self):
  22. d = DNCMemory(
  23. gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
  24. # Add batch dim
  25. packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
  26. [m.random_() for m in packed_state]
  27. original_packed = [m.clone() for m in packed_state]
  28. B, T = packed_state[0].shape[:2]
  29. unpacked = d.unpack_state(packed_state)
  30. packed = d.pack_state(*unpacked)
  31. self.assertTrue(len(packed) > 0)
  32. self.assertEqual(len(packed), len(original_packed))
  33. for m_idx in range(len(packed)):
  34. self.assertTrue(torch.all(packed[m_idx] == original_packed[m_idx]))
  35. def test_dnc_learning(self):
  36. ModelCatalog.register_custom_model("dnc", DNCMemory)
  37. config = {
  38. "env": StatelessCartPole,
  39. "gamma": 0.99,
  40. "num_envs_per_worker": 5,
  41. "framework": "torch",
  42. "num_workers": 1,
  43. "num_cpus_per_worker": 2.0,
  44. "lr": 0.01,
  45. "entropy_coeff": 0.0005,
  46. "vf_loss_coeff": 1e-5,
  47. "model": {
  48. "custom_model": "dnc",
  49. "max_seq_len": 64,
  50. "custom_model_config": {
  51. "nr_cells": 10,
  52. "cell_size": 8,
  53. },
  54. },
  55. }
  56. tune.run("A2C", config=config, stop=self.stop, verbose=1)
  57. if __name__ == "__main__":
  58. import pytest
  59. import sys
  60. sys.exit(pytest.main(["-v", __file__]))