# test_check_env.py — unit tests for RLlib's environment pre-checks.
  1. import logging
  2. import unittest
  3. from unittest.mock import MagicMock, Mock
  4. import gymnasium as gym
  5. import numpy as np
  6. import pytest
  7. from gymnasium.spaces import Box, Dict, Discrete
  8. from ray.rllib.env.base_env import convert_to_base_env
  9. from ray.rllib.env.multi_agent_env import MultiAgentEnvWrapper, make_multi_agent
  10. from ray.rllib.examples.env.parametric_actions_cartpole import ParametricActionsCartPole
  11. from ray.rllib.examples.env.random_env import RandomEnv
  12. from ray.rllib.utils.pre_checks.env import (
  13. check_base_env,
  14. check_env,
  15. check_gym_environments,
  16. check_multiagent_environments,
  17. )
  18. class TestGymCheckEnv(unittest.TestCase):
  19. @pytest.fixture(autouse=True)
  20. def inject_fixtures(self, caplog):
  21. caplog.set_level(logging.CRITICAL)
  22. def test_has_observation_and_action_space(self):
  23. env = Mock(spec=[])
  24. with pytest.raises(AttributeError, match="Env must have observation_space."):
  25. check_gym_environments(env, Mock())
  26. env = Mock(spec=["observation_space"])
  27. with pytest.raises(AttributeError, match="Env must have action_space."):
  28. check_gym_environments(env, Mock())
  29. def test_obs_and_action_spaces_are_gym_spaces(self):
  30. env = RandomEnv()
  31. observation_space = env.observation_space
  32. env.observation_space = "not a gym space"
  33. with pytest.raises(ValueError, match="Observation space must be a gym.space"):
  34. check_env(env)
  35. env.observation_space = observation_space
  36. env.action_space = "not an action space"
  37. with pytest.raises(ValueError, match="Action space must be a gym.space"):
  38. check_env(env)
  39. def test_reset(self):
  40. reset = MagicMock(return_value=5)
  41. env = RandomEnv()
  42. env.reset = reset
  43. # Check reset with out of bounds fails.
  44. error = ".*The observation collected from env.reset().*"
  45. with pytest.raises(ValueError, match=error):
  46. check_env(env)
  47. # Check reset with obs of incorrect type fails.
  48. reset = MagicMock(return_value=float(0.1))
  49. env.reset = reset
  50. with pytest.raises(ValueError, match=error):
  51. check_env(env)
  52. # Check reset with complex obs in which one sub-space is incorrect.
  53. env = RandomEnv(
  54. config={
  55. "observation_space": Dict(
  56. {"a": Discrete(4), "b": Box(-1.0, 1.0, (1,))}
  57. ),
  58. }
  59. )
  60. reset = MagicMock(return_value={"a": float(0.1), "b": np.array([0.5])})
  61. error = ".*The observation collected from env.reset.*\\n path: 'a'.*"
  62. env.reset = reset
  63. self.assertRaisesRegex(ValueError, error, lambda: check_env(env))
  64. def test_step(self):
  65. step = MagicMock(return_value=(5, 5, True, {}))
  66. env = RandomEnv()
  67. env.step = step
  68. error = ".*The observation collected from env.step.*"
  69. with pytest.raises(ValueError, match=error):
  70. check_env(env)
  71. # check reset that returns obs of incorrect type fails
  72. step = MagicMock(return_value=(float(0.1), 5, True, {}))
  73. env.step = step
  74. with pytest.raises(ValueError, match=error):
  75. check_env(env)
  76. # check step that returns reward of non float/int fails
  77. step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
  78. env.step = step
  79. error = "Your step function must return a reward that is integer or float."
  80. with pytest.raises(ValueError, match=error):
  81. check_env(env)
  82. # check step that returns a non bool fails
  83. step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
  84. env.step = step
  85. error = "Your step function must return a done that is a boolean."
  86. with pytest.raises(ValueError, match=error):
  87. check_env(env)
  88. # check step that returns a non dict fails
  89. step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
  90. env.step = step
  91. error = "Your step function must return a info that is a dict."
  92. with pytest.raises(ValueError, match=error):
  93. check_env(env)
  94. def test_parametric_actions(self):
  95. env = ParametricActionsCartPole(10)
  96. check_env(env)
  97. class TestCheckMultiAgentEnv(unittest.TestCase):
  98. @pytest.fixture(autouse=True)
  99. def inject_fixtures(self, caplog):
  100. caplog.set_level(logging.CRITICAL)
  101. def test_check_env_not_correct_type_error(self):
  102. env = RandomEnv()
  103. with pytest.raises(ValueError, match="The passed env is not"):
  104. check_multiagent_environments(env)
  105. def test_check_env_reset_incorrect_error(self):
  106. reset = MagicMock(return_value=5)
  107. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  108. env.reset = reset
  109. with pytest.raises(ValueError, match="The element returned by reset"):
  110. check_env(env)
  111. bad_obs = {
  112. 0: np.array([np.inf, np.inf, np.inf, np.inf]),
  113. 1: np.array([np.inf, np.inf, np.inf, np.inf]),
  114. }
  115. env.reset = lambda *_: bad_obs
  116. with pytest.raises(ValueError, match="The observation collected from env"):
  117. check_env(env)
  118. def test_check_incorrect_space_contains_functions_error(self):
  119. def bad_contains_function(self, x):
  120. raise ValueError("This is a bad contains function")
  121. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  122. env.observation_space_contains = bad_contains_function
  123. with pytest.raises(
  124. ValueError, match="Your observation_space_contains function has some"
  125. ):
  126. check_env(env)
  127. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  128. bad_action = {0: 2, 1: 2}
  129. env.action_space_sample = lambda *_: bad_action
  130. with pytest.raises(
  131. ValueError, match="The action collected from action_space_sample"
  132. ):
  133. check_env(env)
  134. env.action_space_contains = bad_contains_function
  135. with pytest.raises(
  136. ValueError, match="Your action_space_contains function has some error"
  137. ):
  138. check_env(env)
  139. def test_check_env_step_incorrect_error(self):
  140. step = MagicMock(return_value=(5, 5, True, {}))
  141. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  142. sampled_obs, info = env.reset()
  143. env.step = step
  144. with pytest.raises(ValueError, match="The element returned by step"):
  145. check_env(env)
  146. step = MagicMock(return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
  147. env.step = step
  148. with pytest.raises(ValueError, match="Your step function must return rewards"):
  149. check_env(env)
  150. step = MagicMock(return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
  151. env.step = step
  152. with pytest.raises(ValueError, match="Your step function must return dones"):
  153. check_env(env)
  154. step = MagicMock(
  155. return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"})
  156. )
  157. env.step = step
  158. with pytest.raises(ValueError, match="Your step function must return infos"):
  159. check_env(env)
  160. def test_bad_sample_function(self):
  161. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  162. bad_action = {0: 2, 1: 2}
  163. env.action_space_sample = lambda *_: bad_action
  164. with pytest.raises(
  165. ValueError, match="The action collected from action_space_sample"
  166. ):
  167. check_env(env)
  168. env = make_multi_agent("CartPole-v1")({"num_agents": 2})
  169. bad_obs = {
  170. 0: np.array([np.inf, np.inf, np.inf, np.inf]),
  171. 1: np.array([np.inf, np.inf, np.inf, np.inf]),
  172. }
  173. env.observation_space_sample = lambda *_: bad_obs
  174. with pytest.raises(
  175. ValueError,
  176. match="The observation collected from observation_space_sample",
  177. ):
  178. check_env(env)
class TestCheckBaseEnv:
    """Tests for `check_base_env` / `check_env` on BaseEnv (vectorized) envs."""

    def _make_base_env(self):
        """Build a BaseEnv wrapping two 2-agent CartPole sub-envs."""
        del self  # Helper does not use instance state.
        num_envs = 2
        sub_envs = [
            make_multi_agent("CartPole-v1")({"num_agents": 2}) for _ in range(num_envs)
        ]
        env = MultiAgentEnvWrapper(None, sub_envs, 2)
        return env

    def test_check_env_not_correct_type_error(self):
        """A plain gym env is not a BaseEnv."""
        env = RandomEnv()
        with pytest.raises(ValueError, match="The passed env is not"):
            check_base_env(env)

    def test_check_env_reset_incorrect_error(self):
        """try_reset() must return a well-formed, in-bounds MultiEnvDict."""
        # Not a MultiEnvDict at all.
        reset = MagicMock(return_value=5)
        env = self._make_base_env()
        env.try_reset = reset
        with pytest.raises(ValueError, match=("MultiEnvDict. Instead, it is of type")):
            check_env(env)
        # Agent id 2 does not exist in a 2-agent env (valid ids: 0 and 1).
        obs_with_bad_agent_ids = {
            2: np.array([np.inf, np.inf, np.inf, np.inf]),
            1: np.array([np.inf, np.inf, np.inf, np.inf]),
        }
        obs_with_bad_env_ids = {"bad_env_id": obs_with_bad_agent_ids}
        reset = MagicMock(return_value=obs_with_bad_env_ids)
        env.try_reset = reset
        # Top-level keys must correspond to actual env ids.
        with pytest.raises(ValueError, match="has dict keys that don't correspond to"):
            check_env(env)
        reset = MagicMock(return_value={0: obs_with_bad_agent_ids})
        env.try_reset = reset
        # Second-level keys must correspond to actual agent ids.
        with pytest.raises(
            ValueError,
            match="The element returned by "
            "try_reset has agent_ids that are"
            " not the names of the agents",
        ):
            check_env(env)
        # Correct structure, but observations outside the observation space.
        out_of_bounds_obs = {
            0: {
                0: np.array([np.inf, np.inf, np.inf, np.inf]),
                1: np.array([np.inf, np.inf, np.inf, np.inf]),
            }
        }
        env.try_reset = lambda *_: out_of_bounds_obs
        with pytest.raises(
            ValueError, match="The observation collected from try_reset"
        ):
            check_env(env)

    def test_check_space_contains_functions_errors(self):
        """Broken *_space_contains helpers surface as check_env errors."""

        def bad_contains_function(self, x):
            raise ValueError("This is a bad contains function")

        env = self._make_base_env()
        env.observation_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your observation_space_contains function has some"
        ):
            check_env(env)
        env = self._make_base_env()
        env.action_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your action_space_contains function has some error"
        ):
            check_env(env)

    def test_bad_sample_function(self):
        """*_space_sample must produce values the corresponding space contains."""
        env = self._make_base_env()
        # Action 2 is outside CartPole's Discrete(2) action space.
        bad_action = {0: {0: 2, 1: 2}}
        env.action_space_sample = lambda *_: bad_action
        with pytest.raises(
            ValueError, match="The action collected from action_space_sample"
        ):
            check_env(env)
        env = self._make_base_env()
        # Infinite observations are outside the observation space bounds.
        bad_obs = {
            0: {
                0: np.array([np.inf, np.inf, np.inf, np.inf]),
                1: np.array([np.inf, np.inf, np.inf, np.inf]),
            }
        }
        env.observation_space_sample = lambda *_: bad_obs
        with pytest.raises(
            ValueError,
            match="The observation collected from observation_space_sample",
        ):
            check_env(env)

    def test_check_env_step_incorrect_error(self):
        """poll() tuples with malformed obs/reward/terminated/info entries fail."""
        # Well-formed MultiEnvDicts for the fields not under test.
        good_reward = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
        good_terminated = {0: {0: False, 1: False}, 1: {0: False, 1: False}}
        good_info = {0: {0: {}, 1: {}}, 1: {0: {}, 1: {}}}
        env = self._make_base_env()
        # Env 0's entry is an int, not a MultiAgentDict.
        bad_multi_env_dict_obs = {0: 1, 1: {0: np.zeros(4)}}
        poll = MagicMock(
            return_value=(
                bad_multi_env_dict_obs,
                good_reward,
                good_terminated,
                good_info,
                {},
            )
        )
        env.poll = poll
        with pytest.raises(
            ValueError,
            match="The element returned by step, "
            "next_obs contains values that are not"
            " MultiAgentDicts",
        ):
            check_env(env)
        # Non-numeric reward.
        bad_reward = {0: {0: "not_reward", 1: 1}}
        good_obs = env.observation_space_sample()
        poll = MagicMock(
            return_value=(good_obs, bad_reward, good_terminated, good_info, {})
        )
        env.poll = poll
        with pytest.raises(
            ValueError, match="Your step function must return rewards that are"
        ):
            check_env(env)
        # Non-boolean terminated flag.
        bad_terminated = {0: {0: "not_terminated", 1: False}}
        poll = MagicMock(
            return_value=(good_obs, good_reward, bad_terminated, good_info, {})
        )
        env.poll = poll
        with pytest.raises(
            ValueError,
            match="Your step function must return `terminateds` that are boolean.",
        ):
            check_env(env)
        # Non-dict info.
        bad_info = {0: {0: "not_info", 1: {}}}
        poll = MagicMock(
            return_value=(good_obs, good_reward, good_terminated, bad_info, {})
        )
        env.poll = poll
        with pytest.raises(
            ValueError,
            match="Your step function must return infos that are a dict.",
        ):
            check_env(env)

    def test_check_correct_env(self):
        """Well-formed BaseEnvs pass the checker without raising."""
        env = self._make_base_env()
        check_env(env)
        env = gym.make("CartPole-v1")
        env = convert_to_base_env(env)
        check_env(env)
if __name__ == "__main__":
    # Run this file's tests directly via pytest.
    pytest.main()