from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import logging
import numpy as np
import queue
import threading
import time
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Iterator, Optional, Set, Tuple, \
    Type, TYPE_CHECKING, Union

from ray.util.debug import log_once
from ray.rllib.evaluation.collectors.sample_collector import \
    SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
    SimpleListCollector
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env, \
    ASYNC_RESET_RETURN
from ray.rllib.env.wrappers.atari_wrappers import get_wrapper_by_cls, \
    MonitorEnv
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import clip_action, \
    unsquash_action, unbatch
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
    TensorStructType

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.observation_function import ObservationFunction
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.utils import try_import_tf

    _, tf, _ = try_import_tf()
    from gym.envs.classic_control.rendering import SimpleImageViewer

logger = logging.getLogger(__name__)

PolicyEvalData = namedtuple("PolicyEvalData", [
    "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
    "prev_reward"
])

# A batch of RNN states with dimensions [state_index, batch, state_object].
StateBatch = List[List[Any]]
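

# NOTE: Unlike a plain `defaultdict`, `NewEpisodeDefaultDict` passes the
# missing key (the env_id) to its factory, so `active_episodes[env_id]` can
# construct a new Episode that already knows which sub-environment it belongs
# to.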


class NewEpisodeDefaultDict(defaultdict):
    def __missing__(self, env_id):
        if self.default_factory is None:
            raise KeyError(env_id)
        else:
            ret = self[env_id] = self.default_factory(env_id)
            return ret


class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self):
        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def get(self):
        # Mean multiplicator (1000 = sec -> ms).
        factor = 1000 / self.iters
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time *
            factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }
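

# Illustrative sketch of how `_PerfStats.get()` averages and converts to
# milliseconds (hypothetical numbers, not part of the original module):
# 0.25s of accumulated env-wait time over 10 iterations yields
# 0.25 * (1000 / 10) = 25.0 ms on average.
#
#     stats = _PerfStats()
#     stats.iters = 10
#     stats.env_wait_time = 0.25
#     assert stats.get()["mean_env_wait_ms"] == 25.0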


@DeveloperAPI
class SamplerInput(InputReader, metaclass=ABCMeta):
    """Reads input experiences from an existing sampler."""

    @override(InputReader)
    def next(self) -> SampleBatchType:
        batches = [self.get_data()]
        batches.extend(self.get_extra_batches())
        if len(batches) > 1:
            return batches[0].concat_samples(batches)
        else:
            return batches[0]

    @abstractmethod
    @DeveloperAPI
    def get_data(self) -> SampleBatchType:
        """Called by `self.next()` to return the next batch of data.

        Override this in child classes.

        Returns:
            The next batch of data.
        """
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_metrics(self) -> List[RolloutMetrics]:
        """Returns list of episode metrics since the last call to this method.

        The list will contain one RolloutMetrics object per completed episode.

        Returns:
            List of RolloutMetrics objects, one per completed episode since
            the last call to this method.
        """
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_extra_batches(self) -> List[SampleBatchType]:
        """Returns list of extra batches since the last call to this method.

        The list will contain all SampleBatches or
        MultiAgentBatches that the user has provided thus-far. Users can
        add these "extra batches" to an episode by calling the episode's
        `add_extra_batch([SampleBatchType])` method. This can be done from
        inside an overridden `Policy.compute_actions_from_input_dict(...,
        episodes)` or from a custom callback's `on_episode_[start|step|end]()`
        methods.

        Returns:
            List of SampleBatches or MultiAgentBatches provided thus-far by
            the user since the last call to this method.
        """
        raise NotImplementedError
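

# Minimal sketch of a custom SamplerInput subclass (illustrative only; the
# class name and the pre-built batch list below are hypothetical). Only the
# three abstract methods need to be provided; `next()` is inherited and
# concatenates `get_data()` with any extra batches.
#
#     class MyReplaySampler(SamplerInput):
#         def __init__(self, batches: List[SampleBatchType]):
#             self._batches = batches
#
#         def get_data(self) -> SampleBatchType:
#             return self._batches.pop(0)
#
#         def get_metrics(self) -> List[RolloutMetrics]:
#             return []
#
#         def get_extra_batches(self) -> List[SampleBatchType]:
#             return []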


@DeveloperAPI
class SyncSampler(SamplerInput):
    """Sync SamplerInput that collects experiences when `get_data()` is called.
    """

    def __init__(
            self,
            *,
            worker: "RolloutWorker",
            env: BaseEnv,
            clip_rewards: Union[bool, float],
            rollout_fragment_length: int,
            count_steps_by: str = "env_steps",
            callbacks: "DefaultCallbacks",
            horizon: Optional[int] = None,
            multiple_episodes_in_batch: bool = False,
            normalize_actions: bool = True,
            clip_actions: bool = False,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            observation_fn: Optional["ObservationFunction"] = None,
            sample_collector_class: Optional[Type[SampleCollector]] = None,
            render: bool = False,
            # Obsolete.
            policies=None,
            policy_mapping_fn=None,
            preprocessors=None,
            obs_filters=None,
            tf_sess=None,
    ):
        """Initializes a SyncSampler instance.

        Args:
            worker: The RolloutWorker that will use this Sampler for sampling.
            env: Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards: True for +/-1.0 clipping,
                actual float value for +/- value clipping. False for no
                clipping.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            callbacks: The Callbacks object to use when episode
                events happen during rollout.
            horizon: Hard-reset the Env after this many timesteps.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions: Whether to normalize actions to the
                action space's bounds.
            clip_actions: Whether to clip actions according to the
                given action_space's bounds.
            soft_horizon: If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end: Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn: Optional multi-agent observation func to use for
                preprocessing observations.
            sample_collector_class: An optional SampleCollector sub-class to
                use to collect, store, and retrieve environment-, model-,
                and sampler data.
            render: Whether to try to render the environment after each step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        if log_once("deprecated_sync_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")

        self.base_env = convert_to_base_env(env)
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.extra_batches = queue.Queue()
        self.perf_stats = _PerfStats()
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            worker.policy_map,
            clip_rewards,
            callbacks,
            multiple_episodes_in_batch,
            rollout_fragment_length,
            count_steps_by=count_steps_by)
        self.render = render

        # Create the rollout generator to use for calls to `get_data()`.
        self._env_runner = _env_runner(
            worker, self.base_env, self.extra_batches.put, self.horizon,
            normalize_actions, clip_actions, multiple_episodes_in_batch,
            callbacks, self.perf_stats, soft_horizon, no_done_at_end,
            observation_fn, self.sample_collector, self.render)
        self.metrics_queue = queue.Queue()

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        while True:
            item = next(self._env_runner)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                return item

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra
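

# Illustrative usage sketch (hypothetical `my_worker`, `my_env`, and
# `my_callbacks` objects; keyword args as defined in `__init__` above):
#
#     sampler = SyncSampler(
#         worker=my_worker,
#         env=my_env,
#         clip_rewards=False,
#         rollout_fragment_length=200,
#         callbacks=my_callbacks,
#     )
#     batch = sampler.get_data()        # runs the env until a batch is ready
#     metrics = sampler.get_metrics()   # RolloutMetrics of finished episodes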


@DeveloperAPI
class AsyncSampler(threading.Thread, SamplerInput):
    """Async SamplerInput that collects experiences in thread and queues them.

    Once started, experiences are continuously collected in the background
    and put into a Queue, from where they can be unqueued by the caller
    of `get_data()`.
    """

    def __init__(
            self,
            *,
            worker: "RolloutWorker",
            env: BaseEnv,
            clip_rewards: Union[bool, float],
            rollout_fragment_length: int,
            count_steps_by: str = "env_steps",
            callbacks: "DefaultCallbacks",
            horizon: Optional[int] = None,
            multiple_episodes_in_batch: bool = False,
            normalize_actions: bool = True,
            clip_actions: bool = False,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            observation_fn: Optional["ObservationFunction"] = None,
            sample_collector_class: Optional[Type[SampleCollector]] = None,
            render: bool = False,
            blackhole_outputs: bool = False,
            # Obsolete.
            policies=None,
            policy_mapping_fn=None,
            preprocessors=None,
            obs_filters=None,
            tf_sess=None,
    ):
        """Initializes an AsyncSampler instance.

        Args:
            worker: The RolloutWorker that will use this Sampler for sampling.
            env: Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards: True for +/-1.0 clipping,
                actual float value for +/- value clipping. False for no
                clipping.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            horizon: Hard-reset the Env after this many timesteps.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions: Whether to normalize actions to the
                action space's bounds.
            clip_actions: Whether to clip actions according to the
                given action_space's bounds.
            blackhole_outputs: Whether to collect samples, but then
                not further process or store them (throw away all samples).
            soft_horizon: If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end: Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn: Optional multi-agent observation func to use for
                preprocessing observations.
            sample_collector_class: An optional SampleCollector sub-class to
                use to collect, store, and retrieve environment-, model-,
                and sampler data.
            render: Whether to try to render the environment after each step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        if log_once("deprecated_async_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")

        self.worker = worker

        for _, f in worker.filters.items():
            assert getattr(f, "is_concurrent", False), \
                "Observation Filter must support concurrent updates."

        self.base_env = convert_to_base_env(env)
        threading.Thread.__init__(self)
        self.queue = queue.Queue(5)
        self.extra_batches = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.clip_rewards = clip_rewards
        self.daemon = True
        self.multiple_episodes_in_batch = multiple_episodes_in_batch
        self.callbacks = callbacks
        self.normalize_actions = normalize_actions
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.soft_horizon = soft_horizon
        self.no_done_at_end = no_done_at_end
        self.perf_stats = _PerfStats()
        self.shutdown = False
        self.observation_fn = observation_fn
        self.render = render
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            self.worker.policy_map,
            self.clip_rewards,
            self.callbacks,
            self.multiple_episodes_in_batch,
            self.rollout_fragment_length,
            count_steps_by=count_steps_by)

    @override(threading.Thread)
    def run(self):
        try:
            self._run()
        except BaseException as e:
            self.queue.put(e)
            raise e

    def _run(self):
        if self.blackhole_outputs:
            queue_putter = (lambda x: None)
            extra_batches_putter = (lambda x: None)
        else:
            queue_putter = self.queue.put
            extra_batches_putter = (
                lambda x: self.extra_batches.put(x, timeout=600.0))
        env_runner = _env_runner(
            self.worker, self.base_env, extra_batches_putter, self.horizon,
            self.normalize_actions, self.clip_actions,
            self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
            self.soft_horizon, self.no_done_at_end, self.observation_fn,
            self.sample_collector, self.render)
        while not self.shutdown:
            # The timeout variable exists because apparently, if one worker
            # dies, the other workers won't die with it, unless the timeout is
            # set to some large number. This is an empirical observation.
            item = next(env_runner)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                queue_putter(item)

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        if not self.is_alive():
            raise RuntimeError("Sampling thread has died")
        rollout = self.queue.get(timeout=600.0)

        # Propagate errors.
        if isinstance(rollout, BaseException):
            raise rollout

        return rollout

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra
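

# Illustrative usage sketch (hypothetical `my_worker`, `my_env`, and
# `my_callbacks` objects): AsyncSampler is a daemon thread, so it must be
# started before batches can be pulled from its internal queue.
#
#     sampler = AsyncSampler(
#         worker=my_worker,
#         env=my_env,
#         clip_rewards=False,
#         rollout_fragment_length=200,
#         callbacks=my_callbacks,
#     )
#     sampler.start()               # begins collecting in the background
#     batch = sampler.get_data()    # blocks (up to 600s) on the queue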


def _env_runner(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        extra_batch_callback: Callable[[SampleBatchType], None],
        horizon: Optional[int],
        normalize_actions: bool,
        clip_actions: bool,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        perf_stats: _PerfStats,
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: Optional[SampleCollector] = None,
        render: bool = None,
) -> Iterator[SampleBatchType]:
    """This implements the common experience collection logic.

    Args:
        worker: Reference to the current rollout worker.
        base_env: Env implementing BaseEnv.
        extra_batch_callback: function to send extra batch data to.
        horizon: Horizon of the episode.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        normalize_actions: Whether to normalize actions to the action
            space's bounds.
        clip_actions: Whether to clip actions to the space range.
        callbacks: User callbacks to run on episode events.
        perf_stats: Record perf stats into this object.
        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn: Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector: An optional
            SampleCollector object to use.
        render: Whether to try to render the environment after each
            step.

    Yields:
        Object containing state, action, reward, terminal condition,
        and other fields as dictated by `policy`.
    """

    # May be populated for image rendering.
    simple_image_viewer: Optional["SimpleImageViewer"] = None

    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_sub_environments()[
            0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit.
        if max_episode_steps and horizon > max_episode_steps:
            # Try to override the env's own max-step setting with our horizon.
            # If this won't work, throw an error.
            try:
                base_env.get_sub_environments()[
                    0].spec.max_episode_steps = horizon
                base_env.get_sub_environments()[0]._max_episode_steps = horizon
            except Exception:
                raise ValueError(
                    "Your `horizon` setting ({}) is larger than the Env's own "
                    "timestep limit ({}), which seems to be unsettable! Try "
                    "to increase the Env's built-in limit to be at least as "
                    "large as your wanted `horizon`.".format(
                        horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    # No horizon/max_episode_steps -> Episodes may be infinitely long.
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return None

    def new_episode(env_id):
        episode = Episode(
            worker.policy_map,
            worker.policy_mapping_fn,
            get_batch_builder,
            extra_batch_callback,
            env_id=env_id,
            worker=worker,
        )
        # Call each policy's Exploration.on_episode_start method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=p.get_session())
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=worker.policy_map,
            episode=episode,
            env_index=env_id,
        )
        return episode

    active_episodes: Dict[EnvID, Episode] = \
        NewEpisodeDefaultDict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = \
            _process_observations(
                worker=worker,
                base_env=base_env,
                active_episodes=active_episodes,
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,
                infos=infos,
                horizon=horizon,
                multiple_episodes_in_batch=multiple_episodes_in_batch,
                callbacks=callbacks,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector=sample_collector,
            )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=worker.policy_map,
            sample_collector=sample_collector,
            active_episodes=active_episodes,
        )
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=worker.policy_map,
                normalize_actions=normalize_actions,
                clip_actions=clip_actions,
            )
        perf_stats.action_processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

        # Try to render the env, if required.
        if render:
            t5 = time.time()
            # Render can either return an RGB image (uint8 [w x h x 3] numpy
            # array) or take care of rendering itself (returning True).
            rendered = base_env.try_render()
            # Rendering returned an image -> Display it in a SimpleImageViewer.
            if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                # ImageViewer not defined yet, try to create one.
                if simple_image_viewer is None:
                    try:
                        from gym.envs.classic_control.rendering import \
                            SimpleImageViewer
                        simple_image_viewer = SimpleImageViewer()
                    except (ImportError, ModuleNotFoundError):
                        render = False  # disable rendering
                        logger.warning(
                            "Could not import gym.envs.classic_control."
                            "rendering! Try `pip install gym[all]`.")
                if simple_image_viewer:
                    simple_image_viewer.imshow(rendered)
            elif rendered not in [True, False, None]:
                raise ValueError(
                    f"The env's ({base_env}) `try_render()` method returned an"
                    " unsupported value! Make sure you either return a "
                    "uint8/w x h x 3 (RGB) image or handle rendering in a "
                    "window and then return `True`.")
            perf_stats.env_render_time += time.time() - t5
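

# The helpers below implement the three stages of each `_env_runner` loop
# iteration: `_process_observations` turns the raw `BaseEnv.poll()` results
# into per-policy evaluation requests (and emits metrics/sample batches),
# `_do_policy_eval` runs the batched forward passes, and
# `_process_policy_eval_results` converts the policy outputs back into
# per-env action dicts for `BaseEnv.send_actions()`.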


def _process_observations(
        *,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        active_episodes: Dict[EnvID, Episode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker: Reference to the current rollout worker.
        base_env: Env implementing BaseEnv.
        active_episodes: Mapping from
            env ID to currently ongoing Episode object.
        unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards: Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones: Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos: Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon: Horizon of the episode.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks: User callbacks to run on episode events.
        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn: Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector: The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple consisting of 1) active_envs: Set of non-terminated env ids.
        2) to_eval: Map of policy_id to list of agent PolicyEvalData.
        3) outputs: List of metrics and samples to return from the sampler.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # types: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: Episode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
            # Check whether we have to create a fake-last observation
            # for some agents (the environment is not required to do so if
            # dones[__all__]=True).
            for ag_id in episode.get_agents():
                if not episode.last_done_for(
                        ag_id) and ag_id not in all_agents_obs:
                    # Create a fake (all-0s) observation.
                    obs_sp = worker.policy_map[episode.policy_for(
                        ag_id)].observation_space
                    obs_sp = getattr(obs_sp, "original_space", obs_sp)
                    all_agents_obs[ag_id] = tree.map_structure(
                        np.zeros_like, obs_sp.sample())
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        common_infos = infos[env_id].get("__common__", {})
        episode._set_last_info("__common__", common_infos)

        # For each agent in the environment.
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            preprocessor = _get_or_raise(worker.preprocessors, policy_id)
            prep_obs: EnvObsType = raw_obs
            if preprocessor is not None:
                prep_obs = preprocessor.transform(raw_obs)
                if log_once("prep_obs"):
                    logger.info("Preprocessed obs: {}".format(
                        summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(worker.filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_done(agent_id, agent_done)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            elif agent_infos is None or agent_infos.get(
                    "training_enabled", True):
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    SampleBatch.T: episode.length - 1,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    SampleBatch.ACTIONS: episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.DONES: (False
                                        if (no_done_at_end
                                            or (hit_horizon and soft_horizon))
                                        else agent_done),
                    # Next observation.
                    SampleBatch.NEXT_OBS: filtered_obs,
                }
                # Add extra-action-fetches (policy-inference infos) to
                # collectors.
                pol = worker.policy_map[policy_id]
                for key, value in episode.last_extra_action_outs_for(
                        agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos, None
                    if last_observation is None else
                    episode.rnn_state_for(agent_id), None
                    if last_observation is None else
                    episode.last_action_for(agent_id), rewards[env_id].get(
                        agent_id, 0.0))
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each (in-memory) policy's Exploration.on_episode_end
            # method.
            # Note: This may break the exploration (e.g. ParameterNoise) of
            # policies in the `policy_map` that have not been recently used
            # (and are therefore stashed to disk). However, we certainly do not
            # want to loop through all (even stashed) policies here as that
            # would counter the purpose of the LRU policy caching.
            for p in worker.policy_map.cache.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=p.get_session())
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = \
                    {env_id: all_agents_obs}
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = \
                    base_env.try_reset(env_id)
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: Episode = active_episodes[env_id]
                resetted_obs = resetted_obs[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=worker.policy_map,
                        episode=new_episode)
                # types: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    preprocessor = _get_or_raise(worker.preprocessors,
                                                 policy_id)
                    prep_obs: EnvObsType = raw_obs
                    if preprocessor is not None:
                        prep_obs = preprocessor.transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        worker.filters, policy_id)(prep_obs)
                    new_episode._set_last_raw_obs(agent_id, raw_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(
                        new_episode, agent_id, env_id, policy_id,
                        new_episode.length - 1, filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        episode.last_info_for(agent_id) or {},
                        episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs


def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: PolicyMap,
        sample_collector: SampleCollector,
        active_episodes: Dict[EnvID, Episode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval: Mapping of policy IDs to lists of PolicyEvalData objects
            (items in these lists will be the batch's items for the model
            forward pass).
        policies: Mapping from policy ID to Policy obj.
        sample_collector: The SampleCollector object to use.
        active_episodes: Mapping of EnvID to its currently active episode.

    Returns:
        Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        # In case the policyID has been removed from this worker, we need to
        # re-assign policy_id and re-lookup the Policy object to use.
        try:
            policy: Policy = _get_or_raise(policies, policy_id)
        except ValueError:
            # Important: Get the policy_mapping_fn from the active
            # Episode as the policy_mapping_fn from the worker may
            # have already been changed (mapping fns stay constant
            # within one episode).
            episode = active_episodes[eval_data[0].env_id]
            policy_id = episode.policy_mapping_fn(
                eval_data[0].agent_id, episode, worker=episode.worker)
            policy: Policy = _get_or_raise(policies, policy_id)

        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \
            policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results


def _process_policy_eval_results(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
                                           dict]],
        active_episodes: Dict[EnvID, Episode],
        active_envs: Set[int],
        off_policy_actions: MultiEnvDict,
        policies: Dict[PolicyID, Policy],
        normalize_actions: bool,
        clip_actions: bool,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
    """Process the output of policy neural network evaluation.

    Records policy evaluation results into the given episode objects and
    returns replies to send back to agents in the env.

    Args:
        to_eval: Mapping of policy IDs to lists of PolicyEvalData objects.
        eval_results: Mapping of policy IDs to tuples of
            (actions, rnn-out states, extra-action-fetches dict).
        active_episodes: Mapping from env ID to currently ongoing
            Episode object.
        active_envs: Set of non-terminated env ids.
        off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
            off-policy-action, returned by a `BaseEnv.poll()` call.
        policies: Mapping from policy ID to Policy.
        normalize_actions: Whether to normalize actions to the action
            space's bounds.
        clip_actions: Whether to clip actions to the action space's bounds.

    Returns:
        Nested dict of env id -> agent id -> actions to be sent to
        Env (np.ndarrays).
    """

    actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
        defaultdict(dict)

    # types: int
    for env_id in active_envs:
        actions_to_send[env_id] = {}  # at minimum send empty dict

    # types: PolicyID, List[PolicyEvalData]
    for policy_id, eval_data in to_eval.items():
        actions: TensorStructType = eval_results[policy_id][0]
        actions = convert_to_numpy(actions)

        rnn_out_cols: StateBatch = eval_results[policy_id][1]
        extra_action_out_cols: dict = eval_results[policy_id][2]

        # In case actions is a list (representing the 0th dim of a batch of
        # primitive actions), try converting it first.
        if isinstance(actions, list):
            actions = np.array(actions)

        # Store RNN state ins/outs and extra-action fetches to episode.
        for f_i, column in enumerate(rnn_out_cols):
            extra_action_out_cols["state_out_{}".format(f_i)] = column

        policy: Policy = _get_or_raise(policies, policy_id)
        # Split action-component batches into single action rows.
        actions: List[EnvActionType] = unbatch(actions)
        # types: int, EnvActionType
        for i, action in enumerate(actions):
            # Normalize, if necessary.
            if normalize_actions:
                action_to_send = unsquash_action(action,
                                                 policy.action_space_struct)
            # Clip, if necessary.
            elif clip_actions:
                action_to_send = clip_action(action,
                                             policy.action_space_struct)
            else:
                action_to_send = action

            env_id: int = eval_data[i].env_id
            agent_id: AgentID = eval_data[i].agent_id
            episode: Episode = active_episodes[env_id]
            episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
            episode._set_last_extra_action_outs(
                agent_id, {k: v[i]
                           for k, v in extra_action_out_cols.items()})
            if env_id in off_policy_actions and \
                    agent_id in off_policy_actions[env_id]:
                episode._set_last_action(agent_id,
                                         off_policy_actions[env_id][agent_id])
            else:
                episode._set_last_action(agent_id, action)

            assert agent_id not in actions_to_send[env_id]
            actions_to_send[env_id][agent_id] = action_to_send

    return actions_to_send


def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    sub_environments = base_env.get_sub_environments()
    if not sub_environments:
        return None
    atari_out = []
    for sub_env in sub_environments:
        monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out


def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
    num_cols = len(rnn_state_rows[0])
    return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
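

# Illustrative sketch (hypothetical values): `_to_column_format` transposes
# per-agent RNN state rows into per-slot columns, e.g.
#
#     _to_column_format([[h0, c0], [h1, c1]]) == [[h0, h1], [c0, c1]]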


def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
                  policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]:
    """Returns an object under key `policy_id` in `mapping`.

    Args:
        mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
            mapping dict from policy id (str) to actual object (Policy,
            Preprocessor, etc.).
        policy_id (str): The policy ID to lookup.

    Returns:
        Union[Policy, Preprocessor, Filter]: The found object.

    Raises:
        ValueError: If `policy_id` cannot be found in `mapping`.
    """
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy for agent: PolicyID `{}` not found "
            "in policy map, whose keys are `{}`.".format(
                policy_id, mapping.keys()))
    return mapping[policy_id]