# torch_policy.py

import copy
import functools
import gym
import logging
import math
import numpy as np
import os
import threading
import time
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, \
    Union, TYPE_CHECKING

import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import force_list, NullContextManager
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import GradInfoDict, ModelGradients, \
    ModelWeights, TensorType, TensorStructType, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode  # noqa

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)

@DeveloperAPI
class TorchPolicy(Policy):
    """PyTorch specific Policy class to use with RLlib."""

    @DeveloperAPI
    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            *,
            model: Optional[TorchModelV2] = None,
            loss: Optional[Callable[[
                Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
            ], Union[TensorType, List[TensorType]]]] = None,
            action_distribution_class: Optional[Type[
                TorchDistributionWrapper]] = None,
            action_sampler_fn: Optional[Callable[[
                TensorType, List[TensorType]
            ], Tuple[TensorType, TensorType]]] = None,
            action_distribution_fn: Optional[Callable[[
                Policy, ModelV2, TensorType, TensorType, TensorType
            ], Tuple[TensorType, Type[TorchDistributionWrapper], List[
                TensorType]]]] = None,
            max_seq_len: int = 20,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss: Callable that returns one or more (a list of) scalar loss
                terms.
            action_distribution_class: Class for a torch action distribution.
            action_sampler_fn: A callable returning a sampled action and its
                log-likelihood given Policy, ModelV2, input_dict, state batches
                (optional), explore, and timestep.
                Provide `action_sampler_fn` if you would like to have full
                control over the action computation step, including the
                model forward pass, possible sampling from a distribution,
                and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.
        """
        self.framework = config["framework"] = "torch"
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework)
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework)
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and \
                ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = list(range(torch.cuda.device_count()))

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx
                if worker_idx > 0 else "local", "{} fake-GPUs".format(num_gpus)
                if config["_fake_gpus"] else "CPU"))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model
                    for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids) if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None

        self._optimizers = force_list(self.optimizer())
        # Store which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to a set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))
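        # Example (hypothetical): with a single Adam optimizer over all model
        # params, `multi_gpu_param_groups` becomes [set(range(num_params))];
        # with two optimizers split over disjoint parameter subsets, it would
        # hold one index set per optimizer, e.g. [{0, 1, 2}, {3, 4}].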
        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

    @override(Policy)
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            return self._compute_action_helper(input_dict, state_batches,
                                               seq_lens, explore, timestep)

    @override(Policy)
    @DeveloperAPI
    def compute_actions(
            self,
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            prev_reward_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["Episode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]
            return self._compute_action_helper(input_dict, state_batches,
                                               seq_lens, explore, timestep)

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_log_likelihoods(
            self,
            actions: Union[List[TensorStructType], TensorStructType],
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorStructType],
                                              TensorStructType]] = None,
            prev_reward_batch: Optional[Union[List[TensorStructType],
                                              TensorStructType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:

        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError("Cannot compute log-prob/likelihood: an "
                             "`action_sampler_fn` was provided without a "
                             "corresponding `action_distribution_fn`!")

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=False,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, _ = \
                            self.action_distribution_fn(
                                policy=self,
                                model=self.model,
                                obs_batch=input_dict[SampleBatch.CUR_OBS],
                                explore=False,
                                is_training=False)
                    else:
                        raise e
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:

        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats)

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)
        # Step the optimizers.
        self.apply_gradients(_directStepOptimizerSingleton)

        if self.model:
            fetches["model"] = self.model.metrics()
        fetches.update({"custom_metrics": learn_stats})

        return fetches

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
            self,
            batch: SampleBatch,
            buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789 ABC
        # 1) split into n per-GPU sub batches (n=2).
        # [0123 0123456 012] [3 0123456789 ABC]
        # (len=14, 14; seq-lens=[4, 7, 3] and [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [
            slice.to_device(self.devices[i]) for i, slice in enumerate(slices)
        ]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])

    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
        return len(self._loaded_batches[buffer_index])

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!")

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        device_batch_size = \
            self.config.get(
                "sgd_minibatch_size", self.config["train_batch_size"]) // \
            len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset:offset +
                                                   device_batch_size]
            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        if device_batch_size >= sum(
                len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset:offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack(
                            [t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0))
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            batch_fetches[f"tower_{i}"] = {
                LEARNER_STATS_KEY: self.extra_grad_info(batch)
            }

        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_gradients(self,
                          postprocessed_batch: SampleBatch) -> ModelGradients:

        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(
            [postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        if gradients == _directStepOptimizerSingleton:
            for i, opt in enumerate(self._optimizers):
                opt.step()
        else:
            # TODO(sven): Not supported for multiple optimizers yet.
            assert len(self._optimizers) == 1
            for g, p in zip(gradients, self.model.parameters()):
                if g is not None:
                    if torch.is_tensor(g):
                        p.grad = g.to(self.device)
                    else:
                        p.grad = torch.from_numpy(g).to(self.device)

            self._optimizers[0].step()

    @DeveloperAPI
    def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
        """Returns list of per-tower stats, copied to this Policy's device.

        Args:
            stats_name: The name of the stats to average over (this str
                must exist as a key inside each tower's `tower_stats` dict).

        Returns:
            The list of stats tensors (structs) of all towers, copied to this
            Policy's device.

        Raises:
            AssertionError: If the `stats_name` cannot be found in any one
                of the tower's `tower_stats` dicts.
        """
        data = []
        for tower in self.model_gpu_towers:
            if stats_name in tower.tower_stats:
                data.append(
                    tree.map_structure(lambda s: s.to(self.device),
                                       tower.tower_stats[stats_name]))
        assert len(data) > 0, \
            f"Stats `{stats_name}` not found in any of the towers (you have " \
            f"{len(self.model_gpu_towers)} towers in total)! Make " \
            "sure you call the loss function on at least one of the towers."
        return data
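    # A hedged usage sketch for `get_tower_stats()` inside an
    # `extra_grad_info()` override ("policy_loss" is a hypothetical key that
    # a loss function would have stored via `model.tower_stats`):
    #
    #   def extra_grad_info(self, train_batch):
    #       return {
    #           "policy_loss": torch.mean(
    #               torch.stack(self.get_tower_stats("policy_loss"))).item(),
    #       }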

    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> ModelWeights:
        return {
            k: v.cpu().detach().numpy()
            for k, v in self.model.state_dict().items()
        }

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights: ModelWeights) -> None:
        weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(weights)

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return self._is_recurrent

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self.model.get_initial_state())

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        return [
            s.detach().cpu().numpy() for s in self.model.get_initial_state()
        ]

    @override(Policy)
    @DeveloperAPI
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        state = super().get_state()
        state["_optimizer_variables"] = []
        for i, o in enumerate(self._optimizers):
            optim_state_dict = convert_to_numpy(o.state_dict())
            state["_optimizer_variables"].append(optim_state_dict)
        # Add exploration state.
        state["_exploration_state"] = \
            self.exploration.get_state()
        return state

    @override(Policy)
    @DeveloperAPI
    def set_state(self, state: dict) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                optim_state_dict = convert_to_torch_tensor(
                    s, device=self.device)
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])
        # Then the Policy's (NN) weights.
        super().set_state(state)

    @DeveloperAPI
    def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
                           loss: TensorType) -> Dict[str, TensorType]:
        """Called after each optimizer.zero_grad() + loss.backward() call.

        Called for each self._optimizers/loss-value pair.
        Allows for gradient processing before optimizer.step() is called,
        e.g. for gradient clipping.

        Args:
            optimizer: A torch optimizer object.
            loss: The loss tensor associated with the optimizer.

        Returns:
            A dict with information on the gradient processing step.
        """
        return {}

    @DeveloperAPI
    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Extra fetch dict to be added to the fetch dict of the
            `compute_gradients` call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.

    @DeveloperAPI
    def extra_action_out(
            self, input_dict: Dict[str, TensorType],
            state_batches: List[TensorType], model: TorchModelV2,
            action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
        """Returns dict of extra info to include in experience batch.

        Args:
            input_dict: Dict of model input tensors.
            state_batches: List of state tensors.
            model: Reference to the model object.
            action_dist: Torch action dist object
                to get log-probs (e.g. for already sampled actions).

        Returns:
            Extra outputs to return in a `compute_actions_from_input_dict()`
            call (3rd return value).
        """
        return {}

    @DeveloperAPI
    def extra_grad_info(self,
                        train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Return dict of extra grad info.

        Args:
            train_batch: The training batch for which to produce
                extra grad info.

        Returns:
            The info dict carrying grad info per str key.
        """
        return {}

    @DeveloperAPI
    def optimizer(
            self
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Customizes the local PyTorch optimizer(s) to use.

        Returns:
            The local PyTorch optimizer(s) to use for this Policy.
        """
        if hasattr(self, "config"):
            optimizers = [
                torch.optim.Adam(
                    self.model.parameters(), lr=self.config["lr"])
            ]
        else:
            optimizers = [torch.optim.Adam(self.model.parameters())]
        if getattr(self, "exploration", None):
            optimizers = self.exploration.get_exploration_optimizer(optimizers)
        return optimizers
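    # A hedged sketch of overriding `optimizer()` in a subclass, e.g. to swap
    # Adam for SGD with momentum. Note that the number of optimizers returned
    # must match the number of loss terms returned by the loss function (see
    # the assert in `_multi_gpu_parallel_grad_calc()`):
    #
    #   def optimizer(self):
    #       return torch.optim.SGD(
    #           self.model.parameters(), lr=self.config["lr"], momentum=0.9)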

    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str,
                     onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter sets the ONNX OpSet version to use.
        """
        self._lazy_tensor_dict(self._dummy_batch)
        # Provide dummy state inputs if not an RNN (torch cannot jit with
        # returned empty internal states list).
        if "state_in_0" not in self._dummy_batch:
            self._dummy_batch["state_in_0"] = \
                self._dummy_batch[SampleBatch.SEQ_LENS] = np.array([1.0])

        state_ins = []
        i = 0
        while "state_in_{}".format(i) in self._dummy_batch:
            state_ins.append(self._dummy_batch["state_in_{}".format(i)])
            i += 1
        dummy_inputs = {
            k: self._dummy_batch[k]
            for k in self._dummy_batch.keys() if k != "is_training"
        }

        if not os.path.exists(export_dir):
            os.makedirs(export_dir)

        seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]
        if onnx:
            file_name = os.path.join(export_dir, "model.onnx")
            torch.onnx.export(
                self.model, (dummy_inputs, state_ins, seq_lens),
                file_name,
                export_params=True,
                opset_version=onnx,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()) +
                ["state_ins", SampleBatch.SEQ_LENS],
                output_names=["output", "state_outs"],
                dynamic_axes={
                    k: {
                        0: "batch_size"
                    }
                    for k in list(dummy_inputs.keys()) +
                    ["state_ins", SampleBatch.SEQ_LENS]
                })
        else:
            traced = torch.jit.trace(self.model,
                                     (dummy_inputs, state_ins, seq_lens))
            file_name = os.path.join(export_dir, "model.pt")
            traced.save(file_name)

    @override(Policy)
    def export_checkpoint(self, export_dir: str) -> None:
        raise NotImplementedError

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into torch model."""
        return self.model.import_from_h5(import_file)

    @with_lock
    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            actions, logp, state_out = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, state_out = \
                            self.action_distribution_fn(
                                self,
                                self.model,
                                input_dict[SampleBatch.CUR_OBS],
                                explore=explore,
                                timestep=timestep,
                                is_training=False)
                    else:
                        raise e
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            if not (isinstance(dist_class, functools.partial)
                    or issubclass(dist_class, TorchDistributionWrapper)):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = \
                torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        # TODO: (sven): Keep for a while to ensure backward compatibility.
        if not isinstance(postprocessed_batch, SampleBatch):
            postprocessed_batch = SampleBatch(postprocessed_batch)
        postprocessed_batch.set_get_interceptor(
            functools.partial(
                convert_to_torch_tensor, device=device or self.device))
        return postprocessed_batch

    def _multi_gpu_parallel_grad_calc(
            self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        lock = threading.Lock()
        results = {}
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager(
                ) if device.type == "cpu" else torch.cuda.device(device):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch))

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and \
                                    param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx]))

                        grads = []
                        # Note that return values are just references;
                        # calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work
                                # with CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM)
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM)

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info[
                                "allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback
                with lock:
                    results[shard_idx] = (ValueError(
                        e.args[0] + "\n traceback" + traceback.format_exc() +
                        "\n" +
                        "In tower {} on device {}".format(shard_idx, device)),
                        e)

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker,
                    args=(shard_idx, model, sample_batch, device))
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices))
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs

# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
#  and for all possible hyperparams, not just lr.
@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TorchPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
            self.cur_lr = self._lr_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            for opt in self._optimizers:
                for p in opt.param_groups:
                    p["lr"] = self.cur_lr

@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = entropy_coeff
        else:
            # Allows for custom schedule similar to lr_schedule format.
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None)
            else:
                # Implements previous version, but enforces outside_value.
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None)
            self.entropy_coeff = self._entropy_coeff_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            self.entropy_coeff = self._entropy_coeff_schedule.value(
                global_vars["timestep"])

@DeveloperAPI
class DirectStepOptimizer:
    """Typesafe sentinel indicating that `apply_gradients` should directly
    step the optimizers using the gradients already attached in-place to the
    model's parameters.
    """
    _instance = None

    def __new__(cls):
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        return type(self) == type(other)

    def __repr__(self):
        return "DirectStepOptimizer"


_directStepOptimizerSingleton = DirectStepOptimizer()
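
# A minimal, hedged usage check (assuming `ray[rllib]` and torch are
# installed and this module is run directly): the sentinel above is a
# singleton that compares equal by type. Passing it to `apply_gradients()`
# makes the policy step its optimizers directly on the gradients already
# attached to the model's parameters, instead of first copying in externally
# computed gradients.
if __name__ == "__main__":
    assert DirectStepOptimizer() is _directStepOptimizerSingleton
    assert DirectStepOptimizer() == _directStepOptimizerSingleton
    print(repr(_directStepOptimizerSingleton))  # -> DirectStepOptimizer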