external_env.py

from six.moves import queue
import gym
import threading
import uuid
from typing import Callable, Tuple, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType, \
    EnvType, MultiEnvDict

if TYPE_CHECKING:
    from ray.rllib.models.preprocessors import Preprocessor


@PublicAPI
class ExternalEnv(threading.Thread):
    """An environment that interfaces with external agents.

    Unlike simulator envs, control is inverted: The environment queries the
    policy to obtain actions and in return logs observations and rewards for
    training. This is in contrast to gym.Env, where the algorithm drives the
    simulation through env.step() calls.

    You can use ExternalEnv as the backend for policy serving (by serving HTTP
    requests in the run loop), for ingesting offline logs data (by reading
    offline transitions in the run loop), or other custom use cases not easily
    expressed through gym.Env.

    ExternalEnv supports both on-policy actions (through self.get_action()),
    and off-policy actions (through self.log_action()).

    This env is thread-safe, but individual episodes must be executed
    serially.

    Examples:
        >>> register_env("my_env", lambda config: YourExternalEnv(config))
        >>> trainer = DQNTrainer(env="my_env")
        >>> while True:
        >>>     print(trainer.train())
    """

    @PublicAPI
    def __init__(self,
                 action_space: gym.Space,
                 observation_space: gym.Space,
                 max_concurrent: int = 100):
        """Initializes an ExternalEnv instance.

        Args:
            action_space: Action space of the env.
            observation_space: Observation space of the env.
            max_concurrent: Max number of active episodes to allow at
                once. Exceeding this limit raises an error.
        """
        threading.Thread.__init__(self)

        self.daemon = True
        self.action_space = action_space
        self.observation_space = observation_space
        self._episodes = {}
        self._finished = set()
        self._results_avail_condition = threading.Condition()
        self._max_concurrent_episodes = max_concurrent

    @PublicAPI
    def run(self):
        """Override this to implement the run loop.

        Your loop should continuously:
            1. Call self.start_episode(episode_id)
            2. Call self.[get|log]_action(episode_id, obs, [action]?)
            3. Call self.log_returns(episode_id, reward)
            4. Call self.end_episode(episode_id, obs)
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError
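    # A minimal, hypothetical run() sketch (illustration only): `self.env` is
    # an assumed gym.Env created by the subclass, not part of this base class.
    #
    #     def run(self):
    #         while True:
    #             eid = self.start_episode()
    #             obs = self.env.reset()
    #             done = False
    #             while not done:
    #                 action = self.get_action(eid, obs)
    #                 obs, reward, done, _ = self.env.step(action)
    #                 self.log_returns(eid, reward)
    #             self.end_episode(eid, obs)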
    @PublicAPI
    def start_episode(self,
                      episode_id: Optional[str] = None,
                      training_enabled: bool = True) -> str:
        """Record the start of an episode.

        Args:
            episode_id: Unique string id for the episode or
                None for it to be auto-assigned and returned.
            training_enabled: Whether to use experiences for this
                episode to improve the policy.

        Returns:
            Unique string id for the episode.
        """
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError(
                "Episode {} is already started.".format(episode_id))

        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id, self._results_avail_condition, training_enabled)

        return episode_id

    @PublicAPI
    def get_action(self, episode_id: str,
                   observation: EnvObsType) -> EnvActionType:
        """Record an observation and get the on-policy action.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.

        Returns:
            Action from the env action space.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation)

    @PublicAPI
    def log_action(self, episode_id: str, observation: EnvObsType,
                   action: EnvActionType) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
            action: Action for the observation.
        """
        episode = self._get(episode_id)
        episode.log_action(observation, action)
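    # Hedged off-policy sketch: `external_controller` is an assumed external
    # action source, not part of this API. Rewards logged between actions
    # accumulate and are attributed to the previous action.
    #
    #     eid = env.start_episode()
    #     env.log_action(eid, obs, external_controller.act(obs))
    #     env.log_returns(eid, reward)
    #     env.end_episode(eid, final_obs)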
    @PublicAPI
    def log_returns(self,
                    episode_id: str,
                    reward: float,
                    info: Optional[EnvInfoDict] = None) -> None:
        """Records returns (rewards and infos) from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward: Reward from the environment.
            info: Optional info dict.
        """
        episode = self._get(episode_id)
        episode.cur_reward += reward

        if info:
            episode.cur_info = info or {}

    @PublicAPI
    def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
        """Records the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
        """
        episode = self._get(episode_id)
        self._finished.add(episode.episode_id)
        episode.done(observation)

    def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
        """Get a started episode by its ID or raise an error."""
        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id not in self._episodes:
            raise ValueError("Episode {} not found.".format(episode_id))

        return self._episodes[episode_id]
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
    ) -> "BaseEnv":
        """Converts this ExternalEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        if num_envs != 1:
            raise ValueError(
                "External(MultiAgent)Env does not currently support "
                "num_envs > 1. One way of solving this would be to "
                "treat your Env as a MultiAgentEnv hosting only one "
                "type of agent but with several copies.")
        env = ExternalEnvWrapper(self)
        return env
class _ExternalEnvEpisode:
    """Tracked state for each active episode."""

    def __init__(self,
                 episode_id: str,
                 results_avail_condition: threading.Condition,
                 training_enabled: bool,
                 multiagent: bool = False):
        self.episode_id = episode_id
        self.results_avail_condition = results_avail_condition
        self.training_enabled = training_enabled
        self.multiagent = multiagent
        self.data_queue = queue.Queue()
        self.action_queue = queue.Queue()
        if multiagent:
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
            self.cur_done_dict = {"__all__": False}
            self.cur_info_dict = {}
        else:
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            self.cur_done = False
            self.cur_info = {}

    def get_data(self):
        if self.data_queue.empty():
            return None
        return self.data_queue.get_nowait()

    def log_action(self, observation, action):
        if self.multiagent:
            self.new_observation_dict = observation
            self.new_action_dict = action
        else:
            self.new_observation = observation
            self.new_action = action
        self._send()
        # The sampler still responds with an (on-policy) action; block for it
        # and discard it, since the off-policy action was already logged.
        self.action_queue.get(True, timeout=60.0)

    def wait_for_action(self, observation):
        if self.multiagent:
            self.new_observation_dict = observation
        else:
            self.new_observation = observation
        self._send()
        # Block until the sampler computes and sends back an action.
        return self.action_queue.get(True, timeout=300.0)

    def done(self, observation):
        if self.multiagent:
            self.new_observation_dict = observation
            self.cur_done_dict = {"__all__": True}
        else:
            self.new_observation = observation
            self.cur_done = True
        self._send()

    def _send(self):
        if self.multiagent:
            if not self.training_enabled:
                for agent_id in self.cur_info_dict:
                    self.cur_info_dict[agent_id]["training_enabled"] = False
            item = {
                "obs": self.new_observation_dict,
                "reward": self.cur_reward_dict,
                "done": self.cur_done_dict,
                "info": self.cur_info_dict,
            }
            if self.new_action_dict is not None:
                item["off_policy_action"] = self.new_action_dict
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
        else:
            item = {
                "obs": self.new_observation,
                "reward": self.cur_reward,
                "done": self.cur_done,
                "info": self.cur_info,
            }
            if self.new_action is not None:
                item["off_policy_action"] = self.new_action
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            if not self.training_enabled:
                item["info"]["training_enabled"] = False

        with self.results_avail_condition:
            self.data_queue.put_nowait(item)
            self.results_avail_condition.notify()
class ExternalEnvWrapper(BaseEnv):
    """Internal adapter of ExternalEnv to BaseEnv."""

    def __init__(self,
                 external_env: "ExternalEnv",
                 preprocessor: "Preprocessor" = None):
        from ray.rllib.env.external_multi_agent_env import \
            ExternalMultiAgentEnv
        self.external_env = external_env
        self.prep = preprocessor
        self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
        self._action_space = external_env.action_space
        if preprocessor:
            self._observation_space = preprocessor.observation_space
        else:
            self._observation_space = external_env.observation_space
        external_env.start()

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        with self.external_env._results_avail_condition:
            results = self._poll()
            while len(results[0]) == 0:
                self.external_env._results_avail_condition.wait()
                results = self._poll()
                if not self.external_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.external_env._max_concurrent_episodes
        assert len(results[0]) < limit, \
            ("Too many concurrent episodes, were some leaked? This "
             "ExternalEnv was created with max_concurrent={}".format(limit))
        return results

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID
        if self.multiagent:
            for env_id, actions in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(actions)
        else:
            for env_id, action in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(
                    action[_DUMMY_AGENT_ID])

    def _poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                             MultiEnvDict, MultiEnvDict]:
        from ray.rllib.env.base_env import with_dummy_agent_id
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        for eid, episode in self.external_env._episodes.copy().items():
            data = episode.get_data()
            cur_done = episode.cur_done_dict[
                "__all__"] if self.multiagent else episode.cur_done
            if cur_done:
                del self.external_env._episodes[eid]
            if data:
                if self.prep:
                    all_obs[eid] = self.prep.transform(data["obs"])
                else:
                    all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        if self.multiagent:
            # Ensure a consistent set of keys;
            # rely on all_obs having all possible keys for now.
            for eid, eid_dict in all_obs.items():
                for agent_id in eid_dict.keys():

                    def fix(d, zero_val):
                        if agent_id not in d[eid]:
                            d[eid][agent_id] = zero_val

                    fix(all_rewards, 0.0)
                    fix(all_dones, False)
                    fix(all_infos, {})
            return (all_obs, all_rewards, all_dones, all_infos,
                    off_policy_actions)
        else:
            return with_dummy_agent_id(all_obs), \
                with_dummy_agent_id(all_rewards), \
                with_dummy_agent_id(all_dones, "__all__"), \
                with_dummy_agent_id(all_infos), \
                with_dummy_agent_id(off_policy_actions)

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.spaces.Dict:
        return self._observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self._action_space
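# Hedged wiring sketch (illustrative names, not part of this module): a
# concrete subclass passes its spaces to ExternalEnv.__init__() and implements
# run() along the lines of the sketch under ExternalEnv.run(), then is
# registered for training as in the class docstring.
#
#     from ray.rllib.agents.dqn import DQNTrainer
#     from ray.tune.registry import register_env
#
#     class MyServingEnv(ExternalEnv):
#         def __init__(self, env_config):
#             self.env = gym.make("CartPole-v1")  # assumed backing simulator
#             super().__init__(self.env.action_space,
#                              self.env.observation_space)
#
#         def run(self):
#             ...  # see the run() sketch under ExternalEnv.run()
#
#     register_env("my_env", lambda config: MyServingEnv(config))
#     trainer = DQNTrainer(env="my_env")
#     while True:
#         print(trainer.train())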