# test_nested_action_spaces.py (3.9 KB)

import os
import shutil
import tempfile
import unittest

from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.agents.marwil import BCTrainer
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.test_utils import framework_iterator

  13. SPACES = {
  14. "dict": Dict({
  15. "a": Dict({
  16. "aa": Box(-1.0, 1.0, shape=(3, )),
  17. "ab": MultiDiscrete([4, 3]),
  18. }),
  19. "b": Discrete(3),
  20. "c": Tuple([Box(0, 10, (2, ), dtype=np.int32),
  21. Discrete(2)]),
  22. "d": Box(0, 3, (), dtype=np.int64),
  23. }),
  24. "tuple": Tuple([
  25. Tuple([
  26. Box(-1.0, 1.0, shape=(2, )),
  27. Discrete(3),
  28. ]),
  29. MultiDiscrete([4, 3]),
  30. Dict({
  31. "a": Box(0, 100, (), dtype=np.int32),
  32. "b": Discrete(2),
  33. }),
  34. ]),
  35. "multidiscrete": MultiDiscrete([2, 3, 4]),
  36. "intbox": Box(0, 100, (2, ), dtype=np.int32),
  37. }
  38. class NestedActionSpacesTest(unittest.TestCase):
  39. @classmethod
  40. def setUpClass(cls):
  41. ray.init(num_cpus=5)
  42. @classmethod
  43. def tearDownClass(cls):
  44. ray.shutdown()
  45. def test_nested_action_spaces(self):
  46. config = DEFAULT_CONFIG.copy()
  47. config["env"] = RandomEnv
  48. # Write output to check, whether actions are written correctly.
  49. tmp_dir = os.popen("mktemp -d").read()[:-1]
  50. if not os.path.exists(tmp_dir):
  51. # Last resort: Resolve via underlying tempdir (and cut tmp_.
  52. tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
  53. assert os.path.exists(tmp_dir), f"'{tmp_dir}' not found!"
  54. config["output"] = tmp_dir
  55. # Switch off OPE as we don't write action-probs.
  56. # TODO: We should probably always write those if `output` is given.
  57. config["input_evaluation"] = []
  58. # Pretend actions in offline files are already normalized.
  59. config["actions_in_input_normalized"] = True
  60. for _ in framework_iterator(config):
  61. for name, action_space in SPACES.items():
  62. config["env_config"] = {
  63. "action_space": action_space,
  64. }
  65. for flatten in [False, True]:
  66. print(f"A={action_space} flatten={flatten}")
  67. shutil.rmtree(config["output"])
  68. config["_disable_action_flattening"] = not flatten
  69. trainer = PGTrainer(config)
  70. trainer.train()
  71. trainer.stop()
  72. # Check actions in output file (whether properly flattened
  73. # or not).
  74. reader = JsonReader(
  75. inputs=config["output"],
  76. ioctx=trainer.workers.local_worker().io_context)
  77. sample_batch = reader.next()
  78. if flatten:
  79. assert isinstance(sample_batch["actions"], np.ndarray)
  80. assert len(sample_batch["actions"].shape) == 2
  81. assert sample_batch["actions"].shape[0] == len(
  82. sample_batch)
  83. else:
  84. tree.assert_same_structure(
  85. trainer.get_policy().action_space_struct,
  86. sample_batch["actions"])
  87. # Test, whether offline data can be properly read by a
  88. # BCTrainer, configured accordingly.
  89. config["input"] = config["output"]
  90. del config["output"]
  91. bc_trainer = BCTrainer(config=config)
  92. bc_trainer.train()
  93. bc_trainer.stop()
  94. config["output"] = tmp_dir
  95. config["input"] = "sampler"