test_external_env.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import gym
  2. import numpy as np
  3. import random
  4. import unittest
  5. import uuid
  6. import ray
  7. from ray.rllib.agents.dqn import DQNTrainer
  8. from ray.rllib.agents.pg import PGTrainer
  9. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  10. from ray.rllib.env.external_env import ExternalEnv
  11. from ray.rllib.evaluation.tests.test_rollout_worker import (BadPolicy,
  12. MockPolicy)
  13. from ray.rllib.examples.env.mock_env import MockEnv
  14. from ray.rllib.utils.test_utils import framework_iterator
  15. from ray.tune.registry import register_env
  16. def make_simple_serving(multiagent, superclass):
  17. class SimpleServing(superclass):
  18. def __init__(self, env):
  19. superclass.__init__(self, env.action_space, env.observation_space)
  20. self.env = env
  21. def run(self):
  22. eid = self.start_episode()
  23. obs = self.env.reset()
  24. while True:
  25. action = self.get_action(eid, obs)
  26. obs, reward, done, info = self.env.step(action)
  27. if multiagent:
  28. self.log_returns(eid, reward)
  29. else:
  30. self.log_returns(eid, reward, info=info)
  31. if done:
  32. self.end_episode(eid, obs)
  33. obs = self.env.reset()
  34. eid = self.start_episode()
  35. return SimpleServing
  36. # generate & register SimpleServing class
  37. SimpleServing = make_simple_serving(False, ExternalEnv)
  38. class PartOffPolicyServing(ExternalEnv):
  39. def __init__(self, env, off_pol_frac):
  40. ExternalEnv.__init__(self, env.action_space, env.observation_space)
  41. self.env = env
  42. self.off_pol_frac = off_pol_frac
  43. def run(self):
  44. eid = self.start_episode()
  45. obs = self.env.reset()
  46. while True:
  47. if random.random() < self.off_pol_frac:
  48. action = self.env.action_space.sample()
  49. self.log_action(eid, obs, action)
  50. else:
  51. action = self.get_action(eid, obs)
  52. obs, reward, done, info = self.env.step(action)
  53. self.log_returns(eid, reward, info=info)
  54. if done:
  55. self.end_episode(eid, obs)
  56. obs = self.env.reset()
  57. eid = self.start_episode()
  58. class SimpleOffPolicyServing(ExternalEnv):
  59. def __init__(self, env, fixed_action):
  60. ExternalEnv.__init__(self, env.action_space, env.observation_space)
  61. self.env = env
  62. self.fixed_action = fixed_action
  63. def run(self):
  64. eid = self.start_episode()
  65. obs = self.env.reset()
  66. while True:
  67. action = self.fixed_action
  68. self.log_action(eid, obs, action)
  69. obs, reward, done, info = self.env.step(action)
  70. self.log_returns(eid, reward, info=info)
  71. if done:
  72. self.end_episode(eid, obs)
  73. obs = self.env.reset()
  74. eid = self.start_episode()
  75. class MultiServing(ExternalEnv):
  76. def __init__(self, env_creator):
  77. self.env_creator = env_creator
  78. self.env = env_creator()
  79. ExternalEnv.__init__(self, self.env.action_space,
  80. self.env.observation_space)
  81. def run(self):
  82. envs = [self.env_creator() for _ in range(5)]
  83. cur_obs = {}
  84. eids = {}
  85. while True:
  86. active = np.random.choice(range(5), 2, replace=False)
  87. for i in active:
  88. if i not in cur_obs:
  89. eids[i] = uuid.uuid4().hex
  90. self.start_episode(episode_id=eids[i])
  91. cur_obs[i] = envs[i].reset()
  92. actions = [self.get_action(eids[i], cur_obs[i]) for i in active]
  93. for i, action in zip(active, actions):
  94. obs, reward, done, _ = envs[i].step(action)
  95. cur_obs[i] = obs
  96. self.log_returns(eids[i], reward)
  97. if done:
  98. self.end_episode(eids[i], obs)
  99. del cur_obs[i]
  100. class TestExternalEnv(unittest.TestCase):
  101. @classmethod
  102. def setUpClass(cls) -> None:
  103. ray.init(ignore_reinit_error=True)
  104. @classmethod
  105. def tearDownClass(cls) -> None:
  106. ray.shutdown()
  107. def test_external_env_complete_episodes(self):
  108. ev = RolloutWorker(
  109. env_creator=lambda _: SimpleServing(MockEnv(25)),
  110. policy_spec=MockPolicy,
  111. rollout_fragment_length=40,
  112. batch_mode="complete_episodes")
  113. for _ in range(3):
  114. batch = ev.sample()
  115. self.assertEqual(batch.count, 50)
  116. def test_external_env_truncate_episodes(self):
  117. ev = RolloutWorker(
  118. env_creator=lambda _: SimpleServing(MockEnv(25)),
  119. policy_spec=MockPolicy,
  120. rollout_fragment_length=40,
  121. batch_mode="truncate_episodes")
  122. for _ in range(3):
  123. batch = ev.sample()
  124. self.assertEqual(batch.count, 40)
  125. def test_external_env_off_policy(self):
  126. ev = RolloutWorker(
  127. env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
  128. policy_spec=MockPolicy,
  129. rollout_fragment_length=40,
  130. batch_mode="complete_episodes")
  131. for _ in range(3):
  132. batch = ev.sample()
  133. self.assertEqual(batch.count, 50)
  134. self.assertEqual(batch["actions"][0], 42)
  135. self.assertEqual(batch["actions"][-1], 42)
  136. def test_external_env_bad_actions(self):
  137. ev = RolloutWorker(
  138. env_creator=lambda _: SimpleServing(MockEnv(25)),
  139. policy_spec=BadPolicy,
  140. sample_async=True,
  141. rollout_fragment_length=40,
  142. batch_mode="truncate_episodes")
  143. self.assertRaises(Exception, lambda: ev.sample())
  144. def test_train_cartpole_off_policy(self):
  145. register_env(
  146. "test3", lambda _: PartOffPolicyServing(
  147. gym.make("CartPole-v0"), off_pol_frac=0.2))
  148. config = {
  149. "num_workers": 0,
  150. "exploration_config": {
  151. "epsilon_timesteps": 100
  152. },
  153. }
  154. for _ in framework_iterator(config, frameworks=("tf", "torch")):
  155. dqn = DQNTrainer(env="test3", config=config)
  156. reached = False
  157. for i in range(50):
  158. result = dqn.train()
  159. print("Iteration {}, reward {}, timesteps {}".format(
  160. i, result["episode_reward_mean"],
  161. result["timesteps_total"]))
  162. if result["episode_reward_mean"] >= 80:
  163. reached = True
  164. break
  165. if not reached:
  166. raise Exception("failed to improve reward")
  167. def test_train_cartpole(self):
  168. register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
  169. config = {"num_workers": 0}
  170. for _ in framework_iterator(config, frameworks=("tf", "torch")):
  171. pg = PGTrainer(env="test", config=config)
  172. reached = False
  173. for i in range(80):
  174. result = pg.train()
  175. print("Iteration {}, reward {}, timesteps {}".format(
  176. i, result["episode_reward_mean"],
  177. result["timesteps_total"]))
  178. if result["episode_reward_mean"] >= 80:
  179. reached = True
  180. break
  181. if not reached:
  182. raise Exception("failed to improve reward")
  183. def test_train_cartpole_multi(self):
  184. register_env("test2",
  185. lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
  186. config = {"num_workers": 0}
  187. for _ in framework_iterator(config, frameworks=("tf", "torch")):
  188. pg = PGTrainer(env="test2", config=config)
  189. reached = False
  190. for i in range(80):
  191. result = pg.train()
  192. print("Iteration {}, reward {}, timesteps {}".format(
  193. i, result["episode_reward_mean"],
  194. result["timesteps_total"]))
  195. if result["episode_reward_mean"] >= 80:
  196. reached = True
  197. break
  198. if not reached:
  199. raise Exception("failed to improve reward")
  200. def test_external_env_horizon_not_supported(self):
  201. ev = RolloutWorker(
  202. env_creator=lambda _: SimpleServing(MockEnv(25)),
  203. policy_spec=MockPolicy,
  204. episode_horizon=20,
  205. rollout_fragment_length=10,
  206. batch_mode="complete_episodes")
  207. self.assertRaises(ValueError, lambda: ev.sample())
  208. if __name__ == "__main__":
  209. import pytest
  210. import sys
  211. sys.exit(pytest.main(["-v", __file__]))