# base_env.py
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, \
    Union, Set

import gym
import ray
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
    MultiEnvDict

if TYPE_CHECKING:
    from ray.rllib.models.preprocessors import Preprocessor
    from ray.rllib.env.external_env import ExternalEnv
    from ray.rllib.env.multi_agent_env import MultiAgentEnv
    from ray.rllib.env.vector_env import VectorEnv

ASYNC_RESET_RETURN = "async_reset_return"

logger = logging.getLogger(__name__)

@PublicAPI
class BaseEnv:
    """The lowest-level env interface used by RLlib for sampling.

    BaseEnv models multiple agents executing asynchronously in multiple
    vectorized sub-environments. A call to `poll()` returns observations from
    ready agents keyed by their sub-environment ID and agent IDs, and
    actions for those agents can be sent back via `send_actions()`.

    All other RLlib supported env types can be converted to BaseEnv.
    RLlib handles these conversions internally in RolloutWorker, for example:

    gym.Env => rllib.VectorEnv => rllib.BaseEnv
    rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
    rllib.ExternalEnv => rllib.BaseEnv

    Examples:
        >>> env = MyBaseEnv()
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {
            "env_0": {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
            },
            "env_1": {
                "car_0": [8.0, 4.1],
            },
            "env_2": {
                "car_0": [2.3, 3.3],
                "car_1": [1.4, -0.2],
                "car_3": [1.2, 0.1],
            },
        }
        >>> env.send_actions({
        ...     "env_0": {
        ...         "car_0": 0,
        ...         "car_1": 1,
        ...     }, ...
        ... })
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {
            "env_0": {
                "car_0": [4.1, 1.7],
                "car_1": [3.2, -4.2],
            }, ...
        }
        >>> print(dones)
        {
            "env_0": {
                "__all__": False,
                "car_0": False,
                "car_1": True,
            }, ...
        }
    """
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
    ) -> "BaseEnv":
        """Converts an RLlib-supported env into a BaseEnv object.

        Supported types for the `env` arg are gym.Env, BaseEnv,
        VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may also
        be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        TODO: Support gym3 environments, which are already vectorized.

        Args:
            make_env: A callable taking an int as input (which indicates the
                number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote actor.
                You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        del make_env, num_envs, remote_envs, remote_env_batch_wait_ms
        return self
    @PublicAPI
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        """Returns observations from ready agents.

        All return values are two-level dicts mapping from EnvID to dicts
        mapping from AgentIDs to (observation/reward/etc..) values.
        The number of agents and sub-environments may vary over time.

        Returns:
            Tuple consisting of
            1) New observations for each ready agent.
            2) Reward values for each ready agent. If the episode has
                just started, the value will be None.
            3) Done values for each ready agent. The special key "__all__"
                is used to indicate env termination.
            4) Info values for each ready agent.
            5) Agents may take off-policy actions. When that happens, there
                will be an entry in this dict that contains the taken action.
                There is no need to send_actions() for agents that have
                already chosen off-policy actions.
        """
        raise NotImplementedError

    @PublicAPI
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Called to send actions back to running agents in this env.

        Actions should be sent for each ready agent that returned observations
        in the previous poll() call.

        Args:
            action_dict: Action values keyed by env_id and agent_id.
        """
        raise NotImplementedError

    @PublicAPI
    def try_reset(self, env_id: Optional[EnvID] = None
                  ) -> Optional[Union[MultiAgentDict, MultiEnvDict]]:
        """Attempt to reset the sub-env with the given id or all sub-envs.

        If the environment does not support synchronous reset, None can be
        returned here.

        Args:
            env_id: The sub-environment's ID if applicable. If None, reset
                the entire Env (i.e. all sub-environments).

        Note: A MultiAgentDict is returned when using the deprecated wrapper
        classes such as `ray.rllib.env.base_env._MultiAgentEnvToBaseEnv`,
        however, for consistency with the poll() method, a `MultiEnvDict` is
        returned from the new wrapper classes, such as
        `ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper`.

        Returns:
            The reset (multi-agent) observation dict. None if reset is not
            supported.
        """
        return None
    @PublicAPI
    def get_sub_environments(
            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        """Return a reference to the underlying sub environments, if any.

        Args:
            as_dict: If True, return a dict mapping from env_id to env.

        Returns:
            List or dictionary of the underlying sub environments or [] / {}.
        """
        if as_dict:
            return {}
        return []

    @PublicAPI
    def get_agent_ids(self) -> Set[AgentID]:
        """Return the agent ids for the sub_environment.

        Returns:
            All agent ids for the environment.
        """
        return {_DUMMY_AGENT_ID}

    @PublicAPI
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Tries to render the sub-environment with the given id or all.

        Args:
            env_id: The sub-environment's ID, if applicable.
                If None, renders the entire Env (i.e. all sub-environments).
        """
        # By default, do nothing.
        pass

    @PublicAPI
    def stop(self) -> None:
        """Releases all resources used."""
        # Try calling `close` on all sub-environments.
        for env in self.get_sub_environments():
            if hasattr(env, "close"):
                env.close()

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        return self.get_sub_environments()

    @PublicAPI
    @property
    def observation_space(self) -> gym.Space:
        """Returns the observation space for each agent.

        Note: Samples from the observation space need to be preprocessed into
        a `MultiEnvDict` before being used by a policy.

        Returns:
            The observation space for each environment.
        """
        raise NotImplementedError

    @PublicAPI
    @property
    def action_space(self) -> gym.Space:
        """Returns the action space for each agent.

        Note: Samples from the action space need to be preprocessed into a
        `MultiEnvDict` before being passed to `send_actions`.

        Returns:
            The action space for each environment.
        """
        raise NotImplementedError
    @PublicAPI
    def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        """Returns a random action for each environment, and potentially each
        agent in that environment.

        Args:
            agent_id: List of agent ids to sample actions for. If None or
                empty list, sample actions for all agents in the environment.

        Returns:
            A random action for each environment.
        """
        logger.warning("action_space_sample() has not been implemented")
        del agent_id
        return {}

    @PublicAPI
    def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        """Returns a random observation for each environment, and potentially
        each agent in that environment.

        Args:
            agent_id: List of agent ids to sample observations for. If None or
                empty list, sample observations for all agents in the
                environment.

        Returns:
            A random observation for each environment.
        """
        logger.warning("observation_space_sample() has not been implemented")
        del agent_id
        return {}
    @PublicAPI
    def last(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        """Returns the last observations, rewards, and done flags that were
        returned by the environment.

        Returns:
            The last observations, rewards, and done flags for each
            environment.
        """
        logger.warning("last has not been implemented for this environment.")
        return {}, {}, {}, {}, {}
    @PublicAPI
    def observation_space_contains(self, x: MultiEnvDict) -> bool:
        """Checks if the given observations are valid for each environment.

        Args:
            x: Observations to check.

        Returns:
            True if the observations are contained within their respective
                spaces. False otherwise.
        """
        return self._space_contains(self.observation_space, x)

    @PublicAPI
    def action_space_contains(self, x: MultiEnvDict) -> bool:
        """Checks if the given actions are valid for each environment.

        Args:
            x: Actions to check.

        Returns:
            True if the actions are contained within their respective
                spaces. False otherwise.
        """
        return self._space_contains(self.action_space, x)

    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to check x's observations against.
            x: The observations to check.

        Returns:
            True if the observations of x are contained in space.
        """
        agents = set(self.get_agent_ids())
        for multi_agent_dict in x.values():
            for agent_id, obs in multi_agent_dict.items():
                if (agent_id not in agents) or (
                        not space[agent_id].contains(obs)):
                    return False
        return True
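

# Example (editor's sketch, not part of the original RLlib file): a minimal
# rollout loop driving a (multi-agent style) BaseEnv via poll()/send_actions(),
# as described in the class docstring above. `compute_action` is a hypothetical
# stand-in for a policy, not an RLlib API.
def _example_rollout_loop(env: "BaseEnv",
                          compute_action: Callable[[Any], Any],
                          num_steps: int = 100) -> None:
    for _ in range(num_steps):
        obs, rewards, dones, infos, off_policy_actions = env.poll()
        actions = {}
        for env_id, agent_obs_dict in obs.items():
            if dones.get(env_id, {}).get("__all__"):
                # Episode finished -> reset this sub-environment instead of
                # sending further actions to it.
                env.try_reset(env_id)
                continue
            actions[env_id] = {
                agent_id: compute_action(agent_obs)
                for agent_id, agent_obs in agent_obs_dict.items()
                # Agents that already acted off-policy need no action here.
                if agent_id not in off_policy_actions.get(env_id, {})
            }
        env.send_actions(actions)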

# Fixed agent identifier when there is only a single agent in the env.
_DUMMY_AGENT_ID = "agent0"


@Deprecated(new="with_dummy_agent_id", error=False)
def _with_dummy_agent_id(env_id_to_values: Dict[EnvID, Any],
                         dummy_id: "AgentID" = _DUMMY_AGENT_ID
                         ) -> MultiEnvDict:
    return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}


def with_dummy_agent_id(env_id_to_values: Dict[EnvID, Any],
                        dummy_id: "AgentID" = _DUMMY_AGENT_ID) -> MultiEnvDict:
    return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}
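
# Illustration (editor's note): with_dummy_agent_id({0: obs_a, 1: obs_b})
# returns {0: {"agent0": obs_a}, 1: {"agent0": obs_b}}, i.e. it lifts a
# per-sub-environment dict into the two-level MultiEnvDict format used by
# poll()/send_actions().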

@Deprecated(
    old="ray.rllib.env.base_env._ExternalEnvToBaseEnv",
    new="ray.rllib.env.external.ExternalEnvWrapper",
    error=False)
class _ExternalEnvToBaseEnv(BaseEnv):
    """Internal adapter of ExternalEnv to BaseEnv."""

    def __init__(self,
                 external_env: "ExternalEnv",
                 preprocessor: "Preprocessor" = None):
        from ray.rllib.env.external_multi_agent_env import \
            ExternalMultiAgentEnv
        self.external_env = external_env
        self.prep = preprocessor
        self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
        self.action_space = external_env.action_space
        if preprocessor:
            self.observation_space = preprocessor.observation_space
        else:
            self.observation_space = external_env.observation_space
        external_env.start()

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        with self.external_env._results_avail_condition:
            results = self._poll()
            while len(results[0]) == 0:
                self.external_env._results_avail_condition.wait()
                results = self._poll()
                if not self.external_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.external_env._max_concurrent_episodes
        assert len(results[0]) < limit, \
            ("Too many concurrent episodes, were some leaked? This "
             "ExternalEnv was created with max_concurrent={}".format(limit))
        return results

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        if self.multiagent:
            for env_id, actions in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(actions)
        else:
            for env_id, action in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(
                    action[_DUMMY_AGENT_ID])

    def _poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                             MultiEnvDict, MultiEnvDict]:
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        for eid, episode in self.external_env._episodes.copy().items():
            data = episode.get_data()
            cur_done = episode.cur_done_dict[
                "__all__"] if self.multiagent else episode.cur_done
            if cur_done:
                del self.external_env._episodes[eid]
            if data:
                if self.prep:
                    all_obs[eid] = self.prep.transform(data["obs"])
                else:
                    all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        if self.multiagent:
            # Ensure a consistent set of keys;
            # rely on all_obs having all possible keys for now.
            for eid, eid_dict in all_obs.items():
                for agent_id in eid_dict.keys():

                    def fix(d, zero_val):
                        if agent_id not in d[eid]:
                            d[eid][agent_id] = zero_val

                    fix(all_rewards, 0.0)
                    fix(all_dones, False)
                    fix(all_infos, {})
            return (all_obs, all_rewards, all_dones, all_infos,
                    off_policy_actions)
        else:
            return _with_dummy_agent_id(all_obs), \
                _with_dummy_agent_id(all_rewards), \
                _with_dummy_agent_id(all_dones, "__all__"), \
                _with_dummy_agent_id(all_infos), \
                _with_dummy_agent_id(off_policy_actions)

@Deprecated(
    old="ray.rllib.env.base_env._VectorEnvToBaseEnv",
    new="ray.rllib.env.vector_env.VectorEnvWrapper",
    error=False)
class _VectorEnvToBaseEnv(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: "VectorEnv"):
        self.vector_env = vector_env
        self.action_space = vector_env.action_space
        self.observation_space = vector_env.observation_space
        self.num_envs = vector_env.num_envs
        self.new_obs = None  # lazily initialized
        self.cur_rewards = [None for _ in range(self.num_envs)]
        self.cur_dones = [False for _ in range(self.num_envs)]
        self.cur_infos = [None for _ in range(self.num_envs)]

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        if self.new_obs is None:
            self.new_obs = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        dones = dict(enumerate(self.cur_dones))
        infos = dict(enumerate(self.cur_infos))
        self.new_obs = []
        self.cur_rewards = []
        self.cur_dones = []
        self.cur_infos = []
        return _with_dummy_agent_id(new_obs), \
            _with_dummy_agent_id(rewards), \
            _with_dummy_agent_id(dones, "__all__"), \
            _with_dummy_agent_id(infos), {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        action_vector = [None] * self.num_envs
        for i in range(self.num_envs):
            action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
        self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
            self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> MultiAgentDict:
        assert env_id is None or isinstance(env_id, int)
        return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}

    @override(BaseEnv)
    def get_sub_environments(self) -> List[EnvType]:
        return self.vector_env.get_sub_environments()

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)

@Deprecated(
    old="ray.rllib.env.base_env._MultiAgentEnvToBaseEnv",
    new="ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper",
    error=False)
class _MultiAgentEnvToBaseEnv(BaseEnv):
    """Internal adapter of MultiAgentEnv to BaseEnv.

    This also supports vectorization if num_envs > 1.
    """

    def __init__(self, make_env: Callable[[int], EnvType],
                 existing_envs: List["MultiAgentEnv"], num_envs: int):
        """Wraps MultiAgentEnv(s) into the BaseEnv API.

        Args:
            make_env (Callable[[int], EnvType]): Factory that produces a new
                MultiAgentEnv instance. Must be defined, if the number of
                existing envs is less than num_envs.
            existing_envs (List[MultiAgentEnv]): List of already existing
                multi-agent envs.
            num_envs (int): Desired num multiagent envs to have at the end in
                total. This will include the given (already created)
                `existing_envs`.
        """
        from ray.rllib.env.multi_agent_env import MultiAgentEnv
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        self.dones = set()
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self.env_states = [_MultiAgentEnvState(env) for env in self.envs]

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        obs, rewards, dones, infos = {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
        return obs, rewards, dones, infos, {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError("Env {} is already done".format(env_id))
            env = self.envs[env_id]
            obs, rewards, dones, infos = env.step(agent_dict)
            assert isinstance(obs, dict), "Not a multi-agent obs"
            assert isinstance(rewards, dict), "Not a multi-agent reward"
            assert isinstance(dones, dict), "Not a multi-agent return"
            assert isinstance(infos, dict), "Not a multi-agent info"
            # Allow a `__common__` entry in `infos` for data unrelated to any
            # particular agent, but rather to the environment itself.
            if set(infos).difference(set(obs) | {"__common__"}):
                raise ValueError("Key set for infos must be a subset of obs: "
                                 "{} vs {}".format(infos.keys(), obs.keys()))
            if "__all__" not in dones:
                raise ValueError(
                    "In multi-agent environments, '__all__': True|False must "
                    "be included in the 'done' dict: got {}.".format(dones))
            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)

    @override(BaseEnv)
    def try_reset(self,
                  env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        obs = self.env_states[env_id].reset()
        assert isinstance(obs, dict), "Not a multi-agent obs"
        if obs is not None and env_id in self.dones:
            self.dones.remove(env_id)
        return obs

    @override(BaseEnv)
    def get_sub_environments(self) -> List[EnvType]:
        return [state.env for state in self.env_states]

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        if env_id is None:
            env_id = 0
        assert isinstance(env_id, int)
        return self.envs[env_id].render()

@Deprecated(
    old="ray.rllib.env.base_env._MultiAgentEnvState",
    new="ray.rllib.env.multi_agent_env._MultiAgentEnvState",
    error=False)
class _MultiAgentEnvState:
    def __init__(self, env: "MultiAgentEnv"):
        from ray.rllib.env.multi_agent_env import MultiAgentEnv
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.initialized = False

    def poll(
            self
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict,
               MultiAgentDict]:
        if not self.initialized:
            self.reset()
            self.initialized = True

        observations = self.last_obs
        rewards = {}
        dones = {"__all__": self.last_dones["__all__"]}
        infos = {"__common__": self.last_infos.get("__common__")}

        # If episode is done, release everything we have.
        if dones["__all__"]:
            rewards = self.last_rewards
            self.last_rewards = {}
            dones = self.last_dones
            self.last_dones = {}
            self.last_obs = {}
            infos = self.last_infos
            self.last_infos = {}
        # Only release those agents' rewards/dones/infos, whose
        # observations we have.
        else:
            for ag in observations.keys():
                if ag in self.last_rewards:
                    rewards[ag] = self.last_rewards[ag]
                    del self.last_rewards[ag]
                if ag in self.last_dones:
                    dones[ag] = self.last_dones[ag]
                    del self.last_dones[ag]
                if ag in self.last_infos:
                    infos[ag] = self.last_infos[ag]
                    del self.last_infos[ag]
            self.last_dones["__all__"] = False
        return observations, rewards, dones, infos

    def observe(self, obs: MultiAgentDict, rewards: MultiAgentDict,
                dones: MultiAgentDict, infos: MultiAgentDict):
        self.last_obs = obs
        for ag, r in rewards.items():
            if ag in self.last_rewards:
                self.last_rewards[ag] += r
            else:
                self.last_rewards[ag] = r
        for ag, d in dones.items():
            if ag in self.last_dones:
                self.last_dones[ag] = self.last_dones[ag] or d
            else:
                self.last_dones[ag] = d
        self.last_infos = infos

    def reset(self) -> MultiAgentDict:
        self.last_obs = self.env.reset()
        self.last_rewards = {}
        self.last_dones = {"__all__": False}
        self.last_infos = {"__common__": {}}
        return self.last_obs
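
# Illustration (editor's note): _MultiAgentEnvState buffers per-agent data
# between polls. If observe() reports a reward for "car_0" in a step where
# "car_0" returns no observation, the reward is accumulated and only released
# (added up) by poll() once "car_0" shows up in an observation again.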

def convert_to_base_env(
    env: EnvType,
    make_env: Callable[[int], EnvType] = None,
    num_envs: int = 1,
    remote_envs: bool = False,
    remote_env_batch_wait_ms: int = 0,
) -> "BaseEnv":
    """Converts an RLlib-supported env into a BaseEnv object.

    Supported types for the `env` arg are gym.Env, BaseEnv,
    VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.

    The resulting BaseEnv is always vectorized (contains n
    sub-environments) to support batched forward passes, where n may also
    be 1. BaseEnv also supports async execution via the `poll` and
    `send_actions` methods and thus supports external simulators.

    TODO: Support gym3 environments, which are already vectorized.

    Args:
        env: An already existing environment of any supported env type
            to convert/wrap into a BaseEnv. Supported types are gym.Env,
            BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
            ExternalMultiAgentEnv.
        make_env: A callable taking an int as input (which indicates the
            number of individual sub-environments within the final
            vectorized BaseEnv) and returning one individual
            sub-environment.
        num_envs: The number of sub-environments to create in the
            resulting (vectorized) BaseEnv. The already existing `env`
            will be one of the `num_envs`.
        remote_envs: Whether each sub-env should be a @ray.remote actor.
            You can set this behavior in your config via the
            `remote_worker_envs=True` option.
        remote_env_batch_wait_ms: The wait time (in ms) to poll remote
            sub-environments for, if applicable. Only used if
            `remote_envs` is True.

    Returns:
        The resulting BaseEnv object.
    """
    from ray.rllib.env.remote_base_env import RemoteBaseEnv
    from ray.rllib.env.external_env import ExternalEnv
    from ray.rllib.env.multi_agent_env import MultiAgentEnv
    from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper

    if remote_envs and num_envs == 1:
        raise ValueError("Remote envs only make sense to use if num_envs > 1 "
                         "(i.e. vectorization is enabled).")

    # Given `env` is already a BaseEnv -> Return as-is.
    if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
        return env.to_base_env(
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
        )
    # `env` is not a BaseEnv yet -> Need to convert/vectorize.
    else:
        # Sub-environments are ray.remote actors:
        if remote_envs:
            # Determine, whether the already existing sub-env (could
            # be a ray.actor) is multi-agent or not.
            multiagent = ray.get(env._is_multi_agent.remote()) if \
                hasattr(env, "_is_multi_agent") else False
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=multiagent,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                existing_envs=[env],
            )
        # Sub-environments are not ray.remote actors.
        else:
            # Convert gym.Env to VectorEnv ...
            env = VectorEnv.vectorize_gym_envs(
                make_env=make_env,
                existing_envs=[env],
                num_envs=num_envs,
                action_space=env.action_space,
                observation_space=env.observation_space,
            )
            # ... then the resulting VectorEnv to a BaseEnv.
            env = VectorEnvWrapper(env)

    # Make sure conversion went well.
    assert isinstance(env, BaseEnv), env

    return env
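

# Example (editor's sketch, not part of the original RLlib file): converting a
# plain gym.Env into a vectorized BaseEnv and taking one polling step. Assumes
# gym's "CartPole-v1" env is installed; env ids are ints and the single agent
# uses _DUMMY_AGENT_ID, per the single-agent convention above.
def _example_convert_gym_env(num_envs: int = 2) -> None:
    env = gym.make("CartPole-v1")
    base_env = convert_to_base_env(
        env,
        make_env=lambda idx: gym.make("CartPole-v1"),
        num_envs=num_envs,
    )
    obs, rewards, dones, infos, _ = base_env.poll()
    # Send one (random) action per ready agent in every sub-environment.
    base_env.send_actions({
        env_id: {agent_id: env.action_space.sample()
                 for agent_id in agent_obs.keys()}
        for env_id, agent_obs in obs.items()
    })
    base_env.stop()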