torch_policy.py
import copy
import functools
import logging
import math
import os
import threading
import time
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import NullContextManager, force_list
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    GradInfoDict,
    ModelGradients,
    ModelWeights,
    TensorStructType,
    TensorType,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode  # noqa

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


@DeveloperAPI
class TorchPolicy(Policy):
    """PyTorch specific Policy class to use with RLlib."""
    @DeveloperAPI
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[
            Callable[
                [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
                Union[TensorType, List[TensorType]],
            ]
        ] = None,
        action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType, List[TensorType]],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]],
            ]
        ] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss: Callable that returns one or more (a list of) scalar loss
                terms.
            action_distribution_class: Class for a torch action distribution.
            action_sampler_fn: A callable returning either a sampled action,
                its log-likelihood and updated state or a sampled action, its
                log-likelihood, updated state and action distribution inputs
                given Policy, ModelV2, input_dict, state batches (optional),
                explore, and timestep. Provide `action_sampler_fn` if you would
                like to have full control over the action computation step,
                including the model forward pass, possible sampling from a
                distribution, and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.
        """
        self.framework = config["framework"] = "torch"
        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None
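
        # For reference, a `loss` callable (whether passed in or defined as an
        # overridden `self.loss`) follows the signature declared above, e.g.
        # (sketch; `my_loss` is a placeholder name):
        #
        #     def my_loss(policy, model, dist_class, train_batch):
        #         logits, _ = model(train_batch)
        #         action_dist = dist_class(logits, model)
        #         logps = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        #         return -torch.mean(logps)
        #
        # It must return one scalar loss tensor (or a list of them, one per
        # optimizer returned by `self.optimizer()`).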
        self._optimizers = force_list(self.optimizer())

        # Store which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = (
                torch.tensor(
                    [1] * len(state_batches[0]),
                    dtype=torch.long,
                    device=state_batches[0].device,
                )
                if state_batches
                else None
            )

            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )

    @override(Policy)
    @DeveloperAPI
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                    "is_training": False,
                }
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]
            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
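
    # Example call (sketch; `obs1`/`obs2` are placeholder observations):
    #
    #     actions, state_outs, extra = policy.compute_actions(
    #         obs_batch=np.stack([obs1, obs2]),
    #         explore=False,
    #     )
    #
    # `extra` may contain e.g. SampleBatch.ACTION_LOGP, ACTION_PROB and
    # ACTION_DIST_INPUTS, as produced by `_compute_action_helper()` below.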
    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=False,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        dist_inputs, dist_class, _ = self.action_distribution_fn(
                            policy=self,
                            model=self.model,
                            obs_batch=input_dict[SampleBatch.CUR_OBS],
                            explore=False,
                            is_training=False,
                        )
                    else:
                        raise e
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
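
    # Example call (sketch; placeholder arrays): evaluate how likely already
    # taken actions are under the current policy:
    #
    #     logps = policy.compute_log_likelihoods(
    #         actions=np.array([0, 1]),
    #         obs_batch=np.stack([obs1, obs2]),
    #         actions_normalized=True,
    #     )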
    @with_lock
    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:

        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers.
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1
        if self.model:
            fetches["model"] = self.model.metrics()
        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # 1) Split into n per-GPU sub batches (n=2).
        # [0123 0123456 012] [3 0123456789 ABC]
        # (len=14|14, seq-lens=[4, 7, 3]|[1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) Zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])
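
    # Typical multi-GPU workflow (sketch; `minibatch_size` is a placeholder):
    # load a train batch once, then run several SGD passes over offsets of the
    # already-loaded (and already device-resident) data:
    #
    #     loaded = policy.load_batch_into_buffer(train_batch, buffer_index=0)
    #     for offset in range(0, loaded, minibatch_size):
    #         results = policy.learn_on_loaded_batch(offset, buffer_index=0)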
    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
        return sum(len(b) for b in self._loaded_batches[buffer_index])

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        device_batch_size = self.config.get(
            "sgd_minibatch_size", self.config["train_batch_size"]
        ) // len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]
            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.extra_grad_info(batch),
                    "model": model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )

        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:

        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        if gradients == _directStepOptimizerSingleton:
            for i, opt in enumerate(self._optimizers):
                opt.step()
        else:
            # TODO(sven): Not supported for multiple optimizers yet.
            assert len(self._optimizers) == 1
            for g, p in zip(gradients, self.model.parameters()):
                if g is not None:
                    if torch.is_tensor(g):
                        p.grad = g.to(self.device)
                    else:
                        p.grad = torch.from_numpy(g).to(self.device)

            self._optimizers[0].step()

    @DeveloperAPI
    def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
        """Returns list of per-tower stats, copied to this Policy's device.

        Args:
            stats_name: The name of the stats to average over (this str
                must exist as a key inside each tower's `tower_stats` dict).

        Returns:
            The list of stats tensor (structs) of all towers, copied to this
            Policy's device.

        Raises:
            AssertionError: If the `stats_name` cannot be found in any one
                of the tower's `tower_stats` dicts.
        """
        data = []
        for tower in self.model_gpu_towers:
            if stats_name in tower.tower_stats:
                data.append(
                    tree.map_structure(
                        lambda s: s.to(self.device), tower.tower_stats[stats_name]
                    )
                )
        assert len(data) > 0, (
            f"Stats `{stats_name}` not found in any of the towers (you have "
            f"{len(self.model_gpu_towers)} towers in total)! Make "
            "sure you call the loss function on at least one of the towers."
        )
        return data
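
    # Example (sketch; "vf_loss" is a placeholder key): a stats/grad-info
    # function can average a value that each tower's loss function stored
    # under `model.tower_stats["vf_loss"]`:
    #
    #     mean_vf_loss = torch.mean(
    #         torch.stack(policy.get_tower_stats("vf_loss"))
    #     )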
    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> ModelWeights:
        return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights: ModelWeights) -> None:
        weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(weights)

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return self._is_recurrent

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self.model.get_initial_state())

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]

    @override(Policy)
    @DeveloperAPI
    def get_state(self) -> PolicyState:
        state = super().get_state()

        state["_optimizer_variables"] = []
        for i, o in enumerate(self._optimizers):
            optim_state_dict = convert_to_numpy(o.state_dict())
            state["_optimizer_variables"].append(optim_state_dict)
        # Add exploration state.
        if not self.config.get("_enable_rl_module_api", False) and self.exploration:
            # This is not compatible with RLModules, which have a method
            # `forward_exploration` to specify custom exploration behavior.
            state["_exploration_state"] = self.exploration.get_state()
        return state

    @override(Policy)
    @DeveloperAPI
    def set_state(self, state: PolicyState) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc. These
                # parameters should be left as scalars and not converted to tensors;
                # otherwise, torch.optim.step() will start to complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)

    @DeveloperAPI
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        """Called after each optimizer.zero_grad() + loss.backward() call.

        Called for each self._optimizers/loss-value pair.
        Allows for gradient processing before optimizer.step() is called,
        e.g. for gradient clipping.

        Args:
            optimizer: A torch optimizer object.
            loss: The loss tensor associated with the optimizer.

        Returns:
            A dict with information on the gradient processing step.
        """
        return {}

    @DeveloperAPI
    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Extra fetch dict to be added to the fetch dict of the
            `compute_gradients` call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.

    @DeveloperAPI
    def extra_action_out(
        self,
        input_dict: Dict[str, TensorType],
        state_batches: List[TensorType],
        model: TorchModelV2,
        action_dist: TorchDistributionWrapper,
    ) -> Dict[str, TensorType]:
        """Returns dict of extra info to include in experience batch.

        Args:
            input_dict: Dict of model input tensors.
            state_batches: List of state tensors.
            model: Reference to the model object.
            action_dist: Torch action dist object
                to get log-probs (e.g. for already sampled actions).

        Returns:
            Extra outputs to return in a `compute_actions_from_input_dict()`
            call (3rd return value).
        """
        return {}

    @DeveloperAPI
    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Return dict of extra grad info.

        Args:
            train_batch: The training batch to produce extra grad info for.

        Returns:
            The info dict carrying grad info per str key.
        """
        return {}

    @DeveloperAPI
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Customizes the local PyTorch optimizer(s) to use.

        Returns:
            The local PyTorch optimizer(s) to use for this Policy.
        """
        if hasattr(self, "config"):
            optimizers = [
                torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
            ]
        else:
            optimizers = [torch.optim.Adam(self.model.parameters())]
        if self.exploration:
            optimizers = self.exploration.get_exploration_optimizer(optimizers)
        return optimizers
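
    # Override sketch (placeholder learning rates): a subclass may return
    # several optimizers; their number must match the number of loss terms
    # returned by the loss function (see the assert in
    # `_multi_gpu_parallel_grad_calc()` below):
    #
    #     def optimizer(self):
    #         return [
    #             torch.optim.Adam(self.model.parameters(), lr=1e-4),
    #             torch.optim.Adam(self.model.parameters(), lr=3e-4),
    #         ]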
    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter sets the ONNX OpSet version to use.
        """
        os.makedirs(export_dir, exist_ok=True)

        if onnx:
            self._lazy_tensor_dict(self._dummy_batch)
            # Provide dummy state inputs if not an RNN (torch cannot jit with
            # returned empty internal states list).
            if "state_in_0" not in self._dummy_batch:
                self._dummy_batch["state_in_0"] = self._dummy_batch[
                    SampleBatch.SEQ_LENS
                ] = np.array([1.0])

            seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]
            state_ins = []
            i = 0
            while "state_in_{}".format(i) in self._dummy_batch:
                state_ins.append(self._dummy_batch["state_in_{}".format(i)])
                i += 1
            dummy_inputs = {
                k: self._dummy_batch[k]
                for k in self._dummy_batch.keys()
                if k != "is_training"
            }

            file_name = os.path.join(export_dir, "model.onnx")
            torch.onnx.export(
                self.model,
                (dummy_inputs, state_ins, seq_lens),
                file_name,
                export_params=True,
                opset_version=onnx,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys())
                + ["state_ins", SampleBatch.SEQ_LENS],
                output_names=["output", "state_outs"],
                dynamic_axes={
                    k: {0: "batch_size"}
                    for k in list(dummy_inputs.keys())
                    + ["state_ins", SampleBatch.SEQ_LENS]
                },
            )
        # Save the torch.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        else:
            filename = os.path.join(export_dir, "model.pt")
            try:
                torch.save(self.model, f=filename)
            except Exception:
                if os.path.exists(filename):
                    os.remove(filename)
                logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into torch model."""
        return self.model.import_from_h5(import_file)

    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            action_sampler_outputs = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep,
            )
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                actions, logp, state_out = action_sampler_outputs
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            dist_class,
                            state_out,
                        ) = self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        # TODO: (sven): Keep for a while to ensure backward compatibility.
        if not isinstance(postprocessed_batch, SampleBatch):
            postprocessed_batch = SampleBatch(postprocessed_batch)
        postprocessed_batch.set_get_interceptor(
            functools.partial(convert_to_torch_tensor, device=device or self.device)
        )
        return postprocessed_batch

    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        lock = threading.Lock()
        results = {}
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                with lock:
                    results[shard_idx] = (
                        ValueError(
                            f"Error in tower {shard_idx} on device "
                            f"{device} during multi GPU parallel gradient "
                            f"calculation: {e}\n"
                            f"Traceback: \n"
                            f"{traceback.format_exc()}\n"
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs


@DeveloperAPI
class DirectStepOptimizer:
    """Typesafe marker for indicating `apply_gradients` can directly step the
    optimizers with in-place gradients.
    """

    _instance = None

    def __new__(cls):
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        return type(self) is type(other)

    def __repr__(self):
        return "DirectStepOptimizer"


_directStepOptimizerSingleton = DirectStepOptimizer()
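
# Usage note: `learn_on_batch()` and `learn_on_loaded_batch()` above pass this
# singleton to `apply_gradients()`, which then simply calls `opt.step()` on each
# optimizer, relying on the gradients already stored in-place on the model
# parameters by `compute_gradients()` / `_multi_gpu_parallel_grad_calc()`:
#
#     policy.apply_gradients(_directStepOptimizerSingleton)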