torch_policy_v2.py

import copy
import functools
import logging
import math
import os
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.core.models.base import STATE_IN, STATE_OUT
from ray.rllib.core.rl_module import RLModule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import _directStepOptimizerSingleton
from ray.rllib.utils import NullContextManager, force_list
from ray.rllib.utils.annotations import (
    DeveloperAPI,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
    override,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    GradInfoDict,
    ModelGradients,
    ModelWeights,
    PolicyState,
    TensorStructType,
    TensorType,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode  # noqa

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


@DeveloperAPI
class TorchPolicyV2(Policy):
    """PyTorch specific Policy class to use with RLlib."""

    @DeveloperAPI
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        max_seq_len: int = 20,
    ):
        """Initializes a TorchPolicyV2 instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            max_seq_len: Max sequence length for LSTM training.
        """
        self.framework = config["framework"] = "torch"

        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create model.
        if self.config.get("_enable_rl_module_api", False):
            model = self.make_rl_module()
            dist_class = None
        else:
            model, dist_class = self._init_model_and_dist_class()

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.
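        #
        # Illustration (e.g. num_gpus=2 with real GPUs available):
        # self.devices = [cuda:0, cuda:1], self.device = cuda:0, and
        # self.model_gpu_towers holds one copy of the model per device
        # (tower 0 doubles as self.model).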

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        self.dist_class = dist_class
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(tree.flatten(self._state_inputs)) > 0
        if self.config.get("_enable_rl_module_api", False):
            # Maybe update view_requirements, e.g. for recurrent case.
            self.view_requirements = self.model.update_default_view_requirements(
                self.view_requirements
            )
        else:
            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()
            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

        if self.config.get("_enable_rl_module_api", False):
            # We don't need an exploration object with RLModules.
            self.exploration = None
        else:
            self.exploration = self._create_exploration()

        if not self.config.get("_enable_learner_api", False):
            self._optimizers = force_list(self.optimizer())

            # Backward compatibility workaround so Policy will call self.loss()
            # directly.
            # TODO (jungong): clean up after all policies are migrated to new
            #  sub-class implementation.
            self._loss = None

            # Store, which params (by index within the model's list of
            # parameters) should be updated per optimizer.
            # Maps optimizer idx to set of param indices.
            self.multi_gpu_param_groups: List[Set[int]] = []
            main_params = {p: i for i, p in enumerate(self.model.parameters())}
            for o in self._optimizers:
                param_indices = []
                for pg_idx, pg in enumerate(o.param_groups):
                    for p in pg["params"]:
                        param_indices.append(main_params[p])
                self.multi_gpu_param_groups.append(set(param_indices))

            # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
            # one with m towers (num_gpus).
            num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
            self._loaded_batches = [[] for _ in range(num_buffers)]

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.batch_divisibility_req = self.get_batch_divisibility_req()
        self.max_seq_len = max_seq_len

        # If model is an RLModule, it won't have tower_stats; instead, there
        # will be a self.tower_stats[model] -> dict for each tower.
        self.tower_stats = {}
        if not hasattr(self.model, "tower_stats"):
            for model in self.model_gpu_towers:
                self.tower_stats[model] = {}

    def loss_initialized(self):
        return self._loss_initialized

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    @override(Policy)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss function.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            Loss tensor given the input batch.
        """
        # Under the new _enable_learner_api, the loss function still gets called in
        # order to initialize the view requirements of the sample batches that are
        # returned by the sampler. In this case, we don't actually want to compute any
        # loss, however if we access the keys that are needed for a forward_train
        # pass, then the sampler will include those keys in the sample batches it
        # returns. This means that the correct sample batch keys will be available
        # when using the learner group API.
        if self.config._enable_learner_api:
            for k in model.input_specs_train():
                train_batch[k]
            return None
        else:
            raise NotImplementedError
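
    # Note: With `_enable_learner_api=True`, the dummy access above merely touches
    # the keys listed in `model.input_specs_train()` so that the sampler learns
    # which columns to collect; no loss is computed. On the legacy stack,
    # subclasses must override `loss()` (see the illustrative sketch at the end
    # of this file).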

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def action_sampler_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
        """Custom function for sampling new actions given policy.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.

        Returns:
            Sampled action
            Log-likelihood
            Action distribution inputs
            Updated state
        """
        return None, None, None, None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def action_distribution_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, type, List[TensorType]]:
        """Action distribution function for this Policy.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.

        Returns:
            Distribution input.
            ActionDistribution class.
            State outs.
        """
        return None, None, None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def make_model(self) -> ModelV2:
        """Create model.

        Note: only one of make_model or make_model_and_action_dist
        can be overridden.

        Returns:
            ModelV2 model.
        """
        return None

    @ExperimentalAPI
    @override(Policy)
    def maybe_remove_time_dimension(self, input_dict: Dict[str, TensorType]):
        assert self.config.get(
            "_enable_learner_api", False
        ), "This is a helper method for the new learner API."

        if self.config.get("_enable_rl_module_api", False) and self.model.is_stateful():
            # Note that this is a temporary workaround to fit the old sampling stack
            # to RL Modules.
            ret = {}

            def fold_mapping(item):
                item = torch.as_tensor(item)
                size = item.size()
                b_dim, t_dim = list(size[:2])
                other_dims = list(size[2:])
                return item.reshape([b_dim * t_dim] + other_dims)

            for k, v in input_dict.items():
                if k not in (STATE_IN, STATE_OUT):
                    ret[k] = tree.map_structure(fold_mapping, v)
                else:
                    # state in already has time dimension.
                    ret[k] = v

            return ret
        else:
            return input_dict

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        """Create model and action distribution function.

        Returns:
            ModelV2 model.
            ActionDistribution class.
        """
        return None, None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def get_batch_divisibility_req(self) -> int:
        """Get batch divisibility requirement.

        Returns:
            Size N. A sample batch must be of size K*N.
        """
        # By default, any sized batch is ok, so simply return 1.
        return 1
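
    # The value returned here is stored as `self.batch_divisibility_req` in
    # `__init__()` and later handed to `pad_batch_to_sequences_of_same_size()`
    # when batches are prepared for training (see `load_batch_into_buffer()` and
    # `compute_gradients()` below).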

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Stats function. Returns a dict of statistics.

        Args:
            train_batch: The SampleBatch (already) used for training.

        Returns:
            The stats dict.
        """
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        """Called after each optimizer.zero_grad() + loss.backward() call.

        Called for each self._optimizers/loss-value pair.
        Allows for gradient processing before optimizer.step() is called.
        E.g. for gradient clipping.

        Args:
            optimizer: A torch optimizer object.
            loss: The loss tensor associated with the optimizer.

        Returns:
            A dict with information on the gradient processing step.
        """
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Extra fetch dict to be added to the fetch dict of the
            `compute_gradients` call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_action_out(
        self,
        input_dict: Dict[str, TensorType],
        state_batches: List[TensorType],
        model: TorchModelV2,
        action_dist: TorchDistributionWrapper,
    ) -> Dict[str, TensorType]:
        """Returns dict of extra info to include in experience batch.

        Args:
            input_dict: Dict of model input tensors.
            state_batches: List of state tensors.
            model: Reference to the model object.
            action_dist: Torch action dist object
                to get log-probs (e.g. for already sampled actions).

        Returns:
            Extra outputs to return in a `compute_actions_from_input_dict()`
            call (3rd return value).
        """
        return {}

    @override(Policy)
    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        """Postprocesses a trajectory and returns the processed trajectory.

        The trajectory contains only data from one episode and from one agent.
        - If `config.batch_mode=truncate_episodes` (default), sample_batch may
          contain a truncated (at-the-end) episode, in case the
          `config.rollout_fragment_length` was reached by the sampler.
        - If `config.batch_mode=complete_episodes`, sample_batch will contain
          exactly one episode (no matter how long).
        New columns can be added to sample_batch and existing ones may be altered.

        Args:
            sample_batch: The SampleBatch to postprocess.
            other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
                dict of AgentIDs mapping to other agents' trajectory data (from the
                same episode). NOTE: The other agents use the same policy.
            episode (Optional[Episode]): Optional multi-agent episode
                object in which the agents operated.

        Returns:
            SampleBatch: The postprocessed, modified SampleBatch (or a new one).
        """
        return sample_batch

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Customizes the local PyTorch optimizer(s) to use.

        Returns:
            The local PyTorch optimizer(s) to use for this Policy.
        """
        if hasattr(self, "config"):
            optimizers = [
                torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
            ]
        else:
            optimizers = [torch.optim.Adam(self.model.parameters())]
        if self.exploration:
            optimizers = self.exploration.get_exploration_optimizer(optimizers)
        return optimizers
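
    # Policies with several loss terms may return multiple optimizers here, one
    # per loss value returned from `loss()` (the two lists are zipped together
    # in `_multi_gpu_parallel_grad_calc()`). Illustrative sketch only; the
    # `actor_params()`/`critic_params()` helpers are hypothetical:
    #
    #     def optimizer(self):
    #         return [
    #             torch.optim.Adam(self.model.actor_params(), lr=1e-4),
    #             torch.optim.Adam(self.model.critic_params(), lr=1e-3),
    #         ]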

    def _init_model_and_dist_class(self):
        if is_overridden(self.make_model) and is_overridden(
            self.make_model_and_action_dist
        ):
            raise ValueError(
                "Only one of make_model or make_model_and_action_dist "
                "can be overridden."
            )

        if is_overridden(self.make_model):
            model = self.make_model()
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
        elif is_overridden(self.make_model_and_action_dist):
            model, dist_class = self.make_model_and_action_dist()
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
        return model, dist_class

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        seq_lens = None
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)

            if self.config.get("_enable_rl_module_api", False):
                return self._compute_action_helper(
                    input_dict,
                    state_batches=None,
                    seq_lens=None,
                    explore=explore,
                    timestep=timestep,
                )
            else:
                # Pack internal state inputs into (separate) list.
                state_batches = [
                    input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
                ]
                # Calculate RNN sequence lengths.
                if state_batches:
                    seq_lens = torch.tensor(
                        [1] * len(state_batches[0]),
                        dtype=torch.long,
                        device=state_batches[0].device,
                    )

                return self._compute_action_helper(
                    input_dict, state_batches, seq_lens, explore, timestep
                )

    @override(Policy)
    @DeveloperAPI
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                    "is_training": False,
                }
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]
            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        if is_overridden(self.action_sampler_fn) and not is_overridden(
            self.action_distribution_fn
        ):
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
                action_dist = dist_class(dist_inputs, self.model)
            # Default action-dist inputs calculation.
            else:
                if self.config.get("_enable_rl_module_api", False):
                    if in_training:
                        output = self.model.forward_train(input_dict)
                        action_dist_cls = self.model.get_train_action_dist_cls()
                        if action_dist_cls is None:
                            raise ValueError(
                                "The RLModule must provide an appropriate action "
                                "distribution class for training if "
                                "`in_training` is True."
                            )
                    else:
                        output = self.model.forward_exploration(input_dict)
                        action_dist_cls = self.model.get_exploration_action_dist_cls()
                        if action_dist_cls is None:
                            raise ValueError(
                                "The RLModule must provide an appropriate action "
                                "distribution class for exploration if "
                                "`in_training` is False."
                            )

                    action_dist_inputs = output.get(
                        SampleBatch.ACTION_DIST_INPUTS, None
                    )
                    if action_dist_inputs is None:
                        raise ValueError(
                            "The RLModule must provide inputs to create the action "
                            "distribution. These should be part of the output of the "
                            "appropriate forward method under the key "
                            "SampleBatch.ACTION_DIST_INPUTS."
                        )

                    action_dist = action_dist_cls.from_logits(action_dist_inputs)
                else:
                    dist_class = self.dist_class
                    dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)
                    action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers.
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1
        if self.model and hasattr(self.model, "metrics"):
            fetches["model"] = self.model.metrics()
        else:
            fetches["model"] = {}

        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_rl_module_api=self.config.get("_enable_rl_module_api", False),
                padding="last"
                if self.config.get("_enable_rl_module_api", False)
                else "zero",
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789 ABC

        # 1) Split into n per-GPU sub batches (n=2).
        # [0123 0123456 012] [3 0123456789 ABC]
        # (len=14, 14; seq-lens=[4, 7, 3] and [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) Zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_rl_module_api=self.config.get("_enable_rl_module_api", False),
                padding="last"
                if self.config.get("_enable_rl_module_api", False)
                else "zero",
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])
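
    # Note: `buffer_index` selects one of the `num_multi_gpu_tower_stacks` batch
    # buffers created in `__init__()`; each buffer holds one (padded) time-slice
    # per device so that `learn_on_loaded_batch()` can train all towers in
    # parallel on pre-loaded data.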

    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
        return sum(len(b) for b in self._loaded_batches[buffer_index])

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        device_batch_size = self.config.get(
            "sgd_minibatch_size", self.config["train_batch_size"]
        ) // len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]
            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.stats_fn(batch),
                    "model": {}
                    if self.config.get("_enable_rl_module_api", False)
                    else model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_rl_module_api=self.config.get("_enable_rl_module_api", False),
                padding="last"
                if self.config.get("_enable_rl_module_api", False)
                else "zero",
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.stats_fn(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        if gradients == _directStepOptimizerSingleton:
            for i, opt in enumerate(self._optimizers):
                opt.step()
        else:
            # TODO(sven): Not supported for multiple optimizers yet.
            assert len(self._optimizers) == 1
            for g, p in zip(gradients, self.model.parameters()):
                if g is not None:
                    if torch.is_tensor(g):
                        p.grad = g.to(self.device)
                    else:
                        p.grad = torch.from_numpy(g).to(self.device)
            self._optimizers[0].step()
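
    # Typical external usage (e.g. from an Algorithm's training loop):
    #
    #     grads, fetches = policy.compute_gradients(train_batch)
    #     policy.apply_gradients(grads)
    #
    # `learn_on_batch()` above combines both steps by passing the special
    # `_directStepOptimizerSingleton`, which makes the optimizers step directly
    # on the gradients that `compute_gradients()` left on the model parameters.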

    @DeveloperAPI
    def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
        """Returns list of per-tower stats, copied to this Policy's device.

        Args:
            stats_name: The name of the stats to average over (this str
                must exist as a key inside each tower's `tower_stats` dict).

        Returns:
            The list of stats tensor (structs) of all towers, copied to this
            Policy's device.

        Raises:
            AssertionError: If the `stats_name` cannot be found in any one
                of the tower's `tower_stats` dicts.
        """
        data = []
        for model in self.model_gpu_towers:
            if self.tower_stats:
                tower_stats = self.tower_stats[model]
            else:
                tower_stats = model.tower_stats
            if stats_name in tower_stats:
                data.append(
                    tree.map_structure(
                        lambda s: s.to(self.device), tower_stats[stats_name]
                    )
                )
        assert len(data) > 0, (
            f"Stats `{stats_name}` not found in any of the towers (you have "
            f"{len(self.model_gpu_towers)} towers in total)! Make "
            "sure you call the loss function on at least one of the towers."
        )
        return data
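
    # Typical pattern: the loss function stores per-tower values, e.g.
    # `model.tower_stats["policy_loss"] = loss`, and an overridden `stats_fn()`
    # then reduces them, e.g. via
    # `torch.mean(torch.stack(self.get_tower_stats("policy_loss")))`.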

    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> ModelWeights:
        return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights: ModelWeights) -> None:
        weights = convert_to_torch_tensor(weights, device=self.device)
        if self.config.get("_enable_rl_module_api", False):
            self.model.set_state(weights)
        else:
            self.model.load_state_dict(weights)

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return self._is_recurrent

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self.model.get_initial_state())

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        if self.config.get("_enable_rl_module_api", False):
            # Convert the tree of tensors to a tree of numpy arrays.
            return tree.map_structure(
                lambda s: convert_to_numpy(s), self.model.get_initial_state()
            )
        return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]

    @override(Policy)
    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        # Legacy Policy state (w/o torch.nn.Module and w/o PolicySpec).
        state = super().get_state()

        state["_optimizer_variables"] = []
        # In the new Learner API stack, the optimizers live in the learner.
        if not self.config.get("_enable_learner_api", False):
            for i, o in enumerate(self._optimizers):
                optim_state_dict = convert_to_numpy(o.state_dict())
                state["_optimizer_variables"].append(optim_state_dict)
        # Add exploration state.
        if not self.config.get("_enable_rl_module_api", False) and self.exploration:
            # This is not compatible with RLModules, which have a method
            # `forward_exploration` to specify custom exploration behavior.
            state["_exploration_state"] = self.exploration.get_state()
        return state

    @override(Policy)
    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc. These
                # parameters should be left as scalars and not converted to tensors,
                # otherwise torch.optim.step() will start to complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)

    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter sets the ONNX OpSet version to use.
        """
        os.makedirs(export_dir, exist_ok=True)

        enable_rl_module = self.config.get("_enable_rl_module_api", False)
        if enable_rl_module and onnx:
            raise ValueError("ONNX export not supported for RLModule API.")

        if onnx:
            self._lazy_tensor_dict(self._dummy_batch)
            # Provide dummy state inputs if not an RNN (torch cannot jit with
            # returned empty internal states list).
            if "state_in_0" not in self._dummy_batch:
                self._dummy_batch["state_in_0"] = self._dummy_batch[
                    SampleBatch.SEQ_LENS
                ] = np.array([1.0])
            seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

            state_ins = []
            i = 0
            while "state_in_{}".format(i) in self._dummy_batch:
                state_ins.append(self._dummy_batch["state_in_{}".format(i)])
                i += 1
            dummy_inputs = {
                k: self._dummy_batch[k]
                for k in self._dummy_batch.keys()
                if k != "is_training"
            }

            file_name = os.path.join(export_dir, "model.onnx")
            torch.onnx.export(
                self.model,
                (dummy_inputs, state_ins, seq_lens),
                file_name,
                export_params=True,
                opset_version=onnx,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys())
                + ["state_ins", SampleBatch.SEQ_LENS],
                output_names=["output", "state_outs"],
                dynamic_axes={
                    k: {0: "batch_size"}
                    for k in list(dummy_inputs.keys())
                    + ["state_ins", SampleBatch.SEQ_LENS]
                },
            )
        # Save the torch.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        else:
            filename = os.path.join(export_dir, "model.pt")
            try:
                torch.save(self.model, f=filename)
            except Exception:
                if os.path.exists(filename):
                    os.remove(filename)
                logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into torch model."""
        return self.model.import_from_h5(import_file)

    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
            The input_dict is modified in-place to include a numpy copy of the computed
            actions under `SampleBatch.ACTIONS`.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        extra_fetches = dist_inputs = logp = None

        # New API stack: `self.model` is-a RLModule.
        if isinstance(self.model, RLModule):
            if self.model.is_stateful():
                # For recurrent models, we need to add a time dimension.
                if not seq_lens:
                    # In order to calculate the batch size ad hoc, we need a sample
                    # batch.
                    if not isinstance(input_dict, SampleBatch):
                        input_dict = SampleBatch(input_dict)
                    seq_lens = np.array([1] * len(input_dict))
                input_dict = self.maybe_add_time_dimension(
                    input_dict, seq_lens=seq_lens
                )
            input_dict = convert_to_torch_tensor(input_dict, device=self.device)

            # Batches going into the RL Module should not have seq_lens.
            if SampleBatch.SEQ_LENS in input_dict:
                del input_dict[SampleBatch.SEQ_LENS]

            if explore:
                fwd_out = self.model.forward_exploration(input_dict)
                # For recurrent models, we need to remove the time dimension.
                fwd_out = self.maybe_remove_time_dimension(fwd_out)

                # ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
                # Create a distribution object.
                action_dist = None
                if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
                    dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
                    action_dist_class = self.model.get_exploration_action_dist_cls()
                    action_dist = action_dist_class.from_logits(dist_inputs)

                # If `forward_exploration()` returned actions, use them here as-is.
                if SampleBatch.ACTIONS in fwd_out:
                    actions = fwd_out[SampleBatch.ACTIONS]
                # Otherwise, sample actions from the distribution.
                else:
                    if action_dist is None:
                        raise KeyError(
                            "Your RLModule's `forward_exploration()` method must return"
                            f" a dict with either the {SampleBatch.ACTIONS} key or the "
                            f"{SampleBatch.ACTION_DIST_INPUTS} key in it (or both)!"
                        )
                    actions = action_dist.sample()

                # Compute action-logp and action-prob from distribution and add to
                # `extra_fetches`, if possible.
                if action_dist is not None:
                    logp = action_dist.logp(actions)
            else:
                fwd_out = self.model.forward_inference(input_dict)
                # For recurrent models, we need to remove the time dimension.
                fwd_out = self.maybe_remove_time_dimension(fwd_out)

                # ACTION_DIST_INPUTS field returned by `forward_inference()` ->
                # Create a (deterministic) distribution object.
                action_dist = None
                if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
                    dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
                    action_dist_class = self.model.get_inference_action_dist_cls()
                    action_dist = action_dist_class.from_logits(dist_inputs)
                    action_dist = action_dist.to_deterministic()

                # If `forward_inference()` returned actions, use them here as-is.
                if SampleBatch.ACTIONS in fwd_out:
                    actions = fwd_out[SampleBatch.ACTIONS]
                # Otherwise, sample actions from the distribution.
                else:
                    if action_dist is None:
                        raise KeyError(
                            "Your RLModule's `forward_inference()` method must return"
                            f" a dict with either the {SampleBatch.ACTIONS} key or the "
                            f"{SampleBatch.ACTION_DIST_INPUTS} key in it (or both)!"
                        )
                    actions = action_dist.sample()

            # Anything but actions and state_out is an extra fetch.
            state_out = fwd_out.pop(STATE_OUT, {})
            extra_fetches = fwd_out

        elif is_overridden(self.action_sampler_fn):
            action_dist = None
            actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                explore=explore,
                timestep=timestep,
            )
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        # Add default and custom fetches.
        if extra_fetches is None:
            extra_fetches = self.extra_action_out(
                input_dict, state_batches, self.model, action_dist
            )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        if not isinstance(postprocessed_batch, SampleBatch):
            postprocessed_batch = SampleBatch(postprocessed_batch)
        postprocessed_batch.set_get_interceptor(
            functools.partial(convert_to_torch_tensor, device=device or self.device)
        )
        return postprocessed_batch

    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        lock = threading.Lock()
        results = {}
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self.loss(model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    if hasattr(model, "custom_loss"):
                        loss_out = model.custom_loss(loss_out, sample_batch)

                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                with lock:
                    results[shard_idx] = (
                        ValueError(
                            e.args[0]
                            + "\n traceback"
                            + traceback.format_exc()
                            + "\n"
                            + "In tower {} on device {}".format(shard_idx, device)
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs
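

# Illustrative sketch (not part of the original module): on the legacy
# (non-Learner-API) stack, a concrete policy typically subclasses TorchPolicyV2
# and overrides `loss()` plus, optionally, `stats_fn()`:
#
#     class MyTorchPolicy(TorchPolicyV2):
#         def loss(self, model, dist_class, train_batch):
#             logits, _ = model(train_batch)
#             dist = dist_class(logits, model)
#             logp = dist.logp(train_batch[SampleBatch.ACTIONS])
#             loss = -(logp * train_batch[SampleBatch.REWARDS]).mean()
#             model.tower_stats["total_loss"] = loss
#             return loss
#
#         def stats_fn(self, train_batch):
#             return {
#                 "total_loss": torch.mean(
#                     torch.stack(self.get_tower_stats("total_loss"))
#                 )
#             }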