# policy_template.py

import gym
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
    TYPE_CHECKING, Union

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.jax.jax_modelv2 import JAXModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import add_mixins, NullContextManager
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, try_import_jax
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import ModelGradients, TensorType, \
    TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.evaluation.episode import Episode  # noqa

jax, _ = try_import_jax()
torch, _ = try_import_torch()


# TODO: Deprecate in favor of directly sub-classing from TorchPolicy.
@DeveloperAPI
def build_policy_class(
        name: str,
        framework: str,
        *,
        loss_fn: Optional[Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]]],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
                "Episode"]
        ], SampleBatch]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        _after_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
        compute_gradients_fn: Optional[Callable[[Policy, SampleBatch], Tuple[
            ModelGradients, dict]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Helper function for creating a new Policy class at runtime.

    Supports frameworks JAX and PyTorch.

    Args:
        name (str): Name of the Policy class (e.g., "PPOTorchPolicy").
        framework (str): Either "jax" or "torch".
        loss_fn (Optional[Callable[[Policy, ModelV2,
            Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
            List[TensorType]]]]): Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with
            any overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[Any, SampleBatch]], Optional["Episode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the super's `postprocess_trajectory`
            method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and training batch. If None, will use
            `TorchPolicy.extra_grad_info()` instead. The stats dict is used
            for logging (e.g. in TensorBoard).
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to include in experiences. If None, no extra
            computations will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable
            to run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
            instead.
        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable
            to run after the loss init. If None, this step will be skipped.
            This will be deprecated at some point and renamed into
            `after_init` to match `build_tf_policy()` behavior.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2,
            TensorType, TensorType, TensorType], Tuple[TensorType,
            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
            that takes the Policy, Model, the observation batch, an
            explore-flag, a timestep, and an is_training flag and returns a
            tuple of a) distribution inputs (parameters), b) a dist-class to
            generate an action distribution object from, and c) internal-state
            outputs (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a model instance. The distribution class will be
            determined automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a tuple of model instance and torch action distribution
            class. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        compute_gradients_fn (Optional[Callable[
            [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
            callable that takes the sampled batch and computes the gradients
            w.r.t. the loss function. If None, will call the
            `TorchPolicy.compute_gradients()` method instead.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[TorchPolicy]: TorchPolicy child class constructed from the
            specified args.
    """
    original_kwargs = locals().copy()
    parent_cls = TorchPolicy
    base = add_mixins(parent_cls, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            # Set up the config from possible default-config fn and given
            # config arg.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Set the DL framework for this Policy.
            self.framework = self.config["framework"] = framework

            # Validate observation- and action-spaces.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            # Do some pre-initialization steps.
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework=framework)

            # Make sure we were passed a correct Model factory.
            model_cls = TorchModelV2 if framework == "torch" else JAXModelV2
            assert isinstance(self.model, model_cls), \
                f"ERROR: Generated Model must be a {model_cls.__name__} " \
                "object!"

            # Call the framework-specific Policy constructor.
            self.parent_cls = parent_cls
            self.parent_cls.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=None if self.config["in_evaluation"] else loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Merge Model's view requirements into Policy's.
            self.view_requirements.update(self.model.view_requirements)

            _before_loss_init = before_loss_init or after_init
            if _before_loss_init:
                _before_loss_init(self, self.observation_space,
                                  self.action_space, config)

            # Perform test runs through postprocessing- and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=None if self.config["in_evaluation"] else stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak
            # in torch (issue #6962).
            with self._no_grad_context():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    sample_batch, other_agent_batches, episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)
                return sample_batch

        @override(parent_cls)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called,
            e.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return parent_cls.extra_grad_process(self, optimizer, loss)

        @override(parent_cls)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_numpy(extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return parent_cls.extra_compute_grad_fetches(self)

        @override(parent_cls)
        def compute_gradients(self, batch):
            if compute_gradients_fn:
                return compute_gradients_fn(self, batch)
            else:
                return parent_cls.compute_gradients(self, batch)

        @override(parent_cls)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                parent_cls.apply_gradients(self, gradients)

        @override(parent_cls)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with self._no_grad_context():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = parent_cls.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return self._convert_to_numpy(stats_dict)

        @override(parent_cls)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = parent_cls.optimizer(self)
            return optimizers

        @override(parent_cls)
        def extra_grad_info(self, train_batch):
            with self._no_grad_context():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = self.parent_cls.extra_grad_info(
                        self, train_batch)
                return self._convert_to_numpy(stats_dict)

        def _no_grad_context(self):
            if self.framework == "torch":
                return torch.no_grad()
            return NullContextManager()

        def _convert_to_numpy(self, data):
            if self.framework == "torch":
                return convert_to_numpy(data)
            return data

    def with_updates(**overrides):
        """Creates a Torch|JAXPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_policy_class`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new Torch|JAXPolicy sub-class.

        Examples:
            >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
            ..    name="MySpecialDQNPolicyClass",
            ..    loss_fn=some_new_loss_function,
            .. )
        """
        return build_policy_class(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
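

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the original module): a minimal
# example of how `build_policy_class` might be used to assemble a simple
# Torch policy. The names `my_loss_fn`, `my_stats_fn`, and `MyTorchPolicy` are
# hypothetical, and the REINFORCE-style loss below is only a placeholder that
# assumes ACTIONS and REWARDS columns are present in the train batch.
#
# import torch
# from ray.rllib.policy.sample_batch import SampleBatch
#
# def my_loss_fn(policy, model, dist_class, train_batch):
#     # Forward pass through the model, then build the action distribution.
#     dist_inputs, _ = model(train_batch)
#     action_dist = dist_class(dist_inputs, model)
#     log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
#     # Placeholder policy-gradient objective (stand-in for a real loss).
#     return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])
#
# def my_stats_fn(policy, train_batch):
#     # Values returned here show up in the learner stats / TensorBoard.
#     return {"cur_lr": policy.config["lr"]}
#
# MyTorchPolicy = build_policy_class(
#     name="MyTorchPolicy",
#     framework="torch",
#     loss_fn=my_loss_fn,
#     stats_fn=my_stats_fn,
# )
#
# # Variants can then be derived via the generated `with_updates()` helper:
# # MyTweakedPolicy = MyTorchPolicy.with_updates(
# #     name="MyTweakedPolicy", stats_fn=None)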