torch_policy_template.py

import gym
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelGradients, TensorType, \
    TrainerConfigDict

torch, _ = try_import_torch()


@Deprecated(new="build_policy_class(framework='torch')", error=False)
def build_torch_policy(
        name: str,
        *,
        loss_fn: Optional[Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]]],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn=None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        _after_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
        compute_gradients_fn: Optional[Callable[[Policy, SampleBatch], Tuple[
            ModelGradients, dict]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Deprecated: Build a custom TorchPolicy class from the given callables.

    Use `build_policy_class(framework="torch")` instead; this wrapper only
    forwards all of its arguments to that new function.
    """
    kwargs = locals().copy()

    # Set framework to torch and call the new function.
    kwargs["framework"] = "torch"
    return build_policy_class(**kwargs)
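

# A minimal usage sketch (illustration only, not part of the original module):
# `MyTorchPolicy` and `dummy_loss` are hypothetical names, and this assumes
# torch is actually installed. New code should call
# `build_policy_class(framework="torch")` directly, since `build_torch_policy`
# is deprecated.
if __name__ == "__main__":

    def dummy_loss(policy, model, dist_class, train_batch):
        # Placeholder loss returning a constant tensor; a real policy would
        # compute e.g. a policy-gradient loss from `train_batch`.
        return torch.tensor(0.0)

    # Builds (but does not instantiate) a custom TorchPolicy subclass.
    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=dummy_loss,
    )
    print(MyTorchPolicy)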