test_gym_env_apis.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import unittest
  2. import ray
  3. from ray.rllib.algorithms.ppo import PPOConfig
  4. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  5. from ray.rllib.env.wrappers.multi_agent_env_compatibility import (
  6. MultiAgentEnvCompatibility,
  7. )
  8. from ray.rllib.utils.gym import try_import_gymnasium_and_gym
  9. from ray.tune.registry import register_env
# `try_import_gymnasium_and_gym()` returns a pair of modules. In this file
# `gym` is used as the primary (presumably gymnasium) namespace, while
# `old_gym` supplies the legacy package's `Env`/`spaces` for fixtures that
# deliberately use outdated classes. NOTE(review): what happens when either
# package is not installed is decided by the helper and is not visible here.
gym, old_gym = try_import_gymnasium_and_gym()
  11. class GymnasiumOldAPI(gym.Env):
  12. def __init__(self, config=None):
  13. self.observation_space = gym.spaces.Box(-1.0, 1.0, (1,))
  14. self.action_space = gym.spaces.Discrete(2)
  15. def reset(self):
  16. return self.observation_space.sample()
  17. def step(self, action):
  18. done = True
  19. return self.observation_space.sample(), 1.0, done, {}
  20. def seed(self, seed=None):
  21. pass
  22. def render(self, mode="human"):
  23. pass
  24. class GymnasiumNewAPIButOldSpaces(gym.Env):
  25. render_mode = "human"
  26. def __init__(self, config=None):
  27. self.observation_space = old_gym.spaces.Box(-1.0, 1.0, (1,))
  28. self.action_space = old_gym.spaces.Discrete(2)
  29. def reset(self, *, seed=None, options=None):
  30. return self.observation_space.sample(), {}
  31. def step(self, action):
  32. terminated = truncated = True
  33. return self.observation_space.sample(), 1.0, terminated, truncated, {}
  34. def render(self):
  35. pass
  36. class GymnasiumNewAPIButThrowsErrorOnReset(gym.Env):
  37. render_mode = "human"
  38. def __init__(self, config=None):
  39. self.observation_space = gym.spaces.Box(-1.0, 1.0, (1,))
  40. self.action_space = gym.spaces.Discrete(2)
  41. def reset(self, *, seed=None, options=None):
  42. assert False, "kaboom!"
  43. return self.observation_space.sample(), {}
  44. def step(self, action):
  45. terminated = truncated = True
  46. return self.observation_space.sample(), 1.0, terminated, truncated, {}
  47. def render(self):
  48. pass
  49. class OldGymEnv(old_gym.Env):
  50. def __init__(self, config=None):
  51. self.observation_space = old_gym.spaces.Box(-1.0, 1.0, (1,))
  52. self.action_space = old_gym.spaces.Discrete(2)
  53. def reset(self):
  54. return self.observation_space.sample()
  55. def step(self, action):
  56. done = True
  57. return self.observation_space.sample(), 1.0, done, {}
  58. def seed(self, seed=None):
  59. pass
  60. def render(self, mode="human"):
  61. pass
  62. class MultiAgentGymnasiumOldAPI(MultiAgentEnv):
  63. def __init__(self, config=None):
  64. super().__init__()
  65. self.observation_space = gym.spaces.Dict(
  66. {"agent0": gym.spaces.Box(-1.0, 1.0, (1,))}
  67. )
  68. self.action_space = gym.spaces.Dict({"agent0": gym.spaces.Discrete(2)})
  69. self._agent_ids = {"agent0"}
  70. def reset(self):
  71. return {"agent0": self.observation_space.sample()}
  72. def step(self, action):
  73. done = True
  74. return (
  75. {"agent0": self.observation_space.sample()},
  76. {"agent0": 1.0},
  77. {"agent0": done, "__all__": done},
  78. {},
  79. )
  80. def seed(self, seed=None):
  81. pass
  82. def render(self, mode="human"):
  83. pass
  84. class TestGymEnvAPIs(unittest.TestCase):
  85. @classmethod
  86. def setUpClass(cls) -> None:
  87. ray.init()
  88. @classmethod
  89. def tearDownClass(cls) -> None:
  90. ray.shutdown()
  91. def test_gymnasium_old_api(self):
  92. """Tests a gymnasium Env that uses the old API."""
  93. def test_():
  94. (
  95. PPOConfig()
  96. .environment(env=GymnasiumOldAPI, auto_wrap_old_gym_envs=False)
  97. # Forces the error to be raised on the local worker so that it is not
  98. # swallowed by a RayActorError and speeds the test up.
  99. .rollouts(num_rollout_workers=0)
  100. .build()
  101. )
  102. self.assertRaisesRegex(
  103. ValueError,
  104. ".*In particular, the `reset\\(\\)` method seems to be faulty..*",
  105. lambda: test_(),
  106. )
  107. def test_gymnasium_old_api_using_auto_wrap(self):
  108. """Tests a gymnasium Env that uses the old API, but is auto-wrapped by RLlib."""
  109. algo = (
  110. PPOConfig()
  111. .environment(env=GymnasiumOldAPI, auto_wrap_old_gym_envs=True)
  112. # Speeds the test up.
  113. .rollouts(num_rollout_workers=0)
  114. .build()
  115. )
  116. algo.train()
  117. algo.stop()
  118. def test_gymnasium_new_api_but_old_spaces(self):
  119. """Tests a gymnasium Env that uses the new API, but has old spaces."""
  120. def test_():
  121. (
  122. PPOConfig()
  123. .environment(GymnasiumNewAPIButOldSpaces, auto_wrap_old_gym_envs=True)
  124. # Forces the error to be raised on the local worker so that it is not
  125. # swallowed by a RayActorError and speeds the test up.
  126. .rollouts(num_rollout_workers=0)
  127. .build()
  128. )
  129. self.assertRaisesRegex(
  130. ValueError,
  131. "Observation space must be a gymnasium.Space!",
  132. lambda: test_(),
  133. )
  134. def test_gymnasium_new_api_but_throws_error_on_reset(self):
  135. """Tests a gymnasium Env that uses the new API, but errors on reset() call."""
  136. def test_():
  137. (
  138. PPOConfig()
  139. .environment(
  140. GymnasiumNewAPIButThrowsErrorOnReset,
  141. auto_wrap_old_gym_envs=True,
  142. )
  143. # Forces the error to be raised on the local worker so that it is not
  144. # swallowed by a RayActorError and speeds the test up.
  145. .rollouts(num_rollout_workers=0)
  146. .build()
  147. )
  148. self.assertRaisesRegex(AssertionError, "kaboom!", lambda: test_())
  149. def test_gymnasium_old_api_but_manually_wrapped(self):
  150. """Tests a gymnasium Env that uses the old API, but is correctly wrapped."""
  151. from gymnasium.wrappers import EnvCompatibility
  152. register_env(
  153. "test",
  154. lambda env_ctx: EnvCompatibility(GymnasiumOldAPI(env_ctx)),
  155. )
  156. algo = (
  157. PPOConfig()
  158. .environment("test", auto_wrap_old_gym_envs=False)
  159. # Speeds the test up.
  160. .rollouts(num_rollout_workers=0)
  161. .build()
  162. )
  163. algo.train()
  164. algo.stop()
  165. def test_old_gym_env(self):
  166. """Tests a old gym.Env (should fail, even with auto-wrapping enabled)."""
  167. def test_():
  168. (
  169. PPOConfig()
  170. .environment(env=OldGymEnv, auto_wrap_old_gym_envs=True)
  171. # Forces the error to be raised on the local worker so that it is not
  172. # swallowed by a RayActorError and speeds the test up.
  173. .rollouts(num_rollout_workers=0)
  174. .build()
  175. )
  176. self.assertRaisesRegex(
  177. ValueError,
  178. "does not abide to the new gymnasium-style API",
  179. lambda: test_(),
  180. )
  181. def test_multi_agent_gymnasium_old_api(self):
  182. """Tests a MultiAgentEnv (gymnasium.Env subclass) that uses the old API."""
  183. def test_():
  184. (
  185. PPOConfig()
  186. .environment(
  187. MultiAgentGymnasiumOldAPI,
  188. auto_wrap_old_gym_envs=False,
  189. )
  190. # Forces the error to be raised on the local worker so that it is not
  191. # swallowed by a RayActorError and speeds the test up.
  192. .rollouts(num_rollout_workers=0)
  193. .build()
  194. )
  195. self.assertRaisesRegex(
  196. ValueError,
  197. ".*In particular, the `reset\\(\\)` method seems to be faulty..*",
  198. lambda: test_(),
  199. )
  200. def test_multi_agent_gymnasium_old_api_auto_wrapped(self):
  201. """Tests a MultiAgentEnv (gymnasium.Env subclass) that uses the old API."""
  202. algo = (
  203. PPOConfig()
  204. .environment(
  205. MultiAgentGymnasiumOldAPI,
  206. auto_wrap_old_gym_envs=True,
  207. disable_env_checking=True,
  208. )
  209. # Speeds the test up.
  210. .rollouts(num_rollout_workers=0)
  211. .build()
  212. )
  213. algo.train()
  214. algo.stop()
  215. def test_multi_agent_gymnasium_old_api_manually_wrapped(self):
  216. """Tests a MultiAgentEnv (gymnasium.Env subclass) that uses the old API."""
  217. register_env(
  218. "test",
  219. lambda env_ctx: MultiAgentEnvCompatibility(
  220. MultiAgentGymnasiumOldAPI(env_ctx)
  221. ),
  222. )
  223. algo = (
  224. PPOConfig()
  225. .environment(
  226. "test", auto_wrap_old_gym_envs=False, disable_env_checking=True
  227. )
  228. # Speeds the test up.
  229. .rollouts(num_rollout_workers=0)
  230. .build()
  231. )
  232. algo.train()
  233. algo.stop()
  234. if __name__ == "__main__":
  235. import pytest
  236. import sys
  237. sys.exit(pytest.main(["-v", __file__]))