  1. import numpy as np
  2. import gym
  3. from ray.rllib.models.modelv2 import ModelV2
  4. from ray.rllib.utils.annotations import DeveloperAPI
  5. from ray.rllib.utils.typing import TensorType, List, Union, ModelConfigDict
  6. @DeveloperAPI
  7. class ActionDistribution:
  8. """The policy action distribution of an agent.
  9. Attributes:
  10. inputs (Tensors): input vector to compute samples from.
  11. model (ModelV2): reference to model producing the inputs.
  12. """
  13. @DeveloperAPI
  14. def __init__(self, inputs: List[TensorType], model: ModelV2):
  15. """Initializes an ActionDist object.
  16. Args:
  17. inputs (Tensors): input vector to compute samples from.
  18. model (ModelV2): reference to model producing the inputs. This
  19. is mainly useful if you want to use model variables to compute
  20. action outputs (i.e., for auto-regressive action distributions,
  21. see examples/autoregressive_action_dist.py).
  22. """
  23. self.inputs = inputs
  24. self.model = model
  25. @DeveloperAPI
  26. def sample(self) -> TensorType:
  27. """Draw a sample from the action distribution."""
  28. raise NotImplementedError
  29. @DeveloperAPI
  30. def deterministic_sample(self) -> TensorType:
  31. """
  32. Get the deterministic "sampling" output from the distribution.
  33. This is usually the max likelihood output, i.e. mean for Normal, argmax
  34. for Categorical, etc..
  35. """
  36. raise NotImplementedError
  37. @DeveloperAPI
  38. def sampled_action_logp(self) -> TensorType:
  39. """Returns the log probability of the last sampled action."""
  40. raise NotImplementedError
  41. @DeveloperAPI
  42. def logp(self, x: TensorType) -> TensorType:
  43. """The log-likelihood of the action distribution."""
  44. raise NotImplementedError
  45. @DeveloperAPI
  46. def kl(self, other: "ActionDistribution") -> TensorType:
  47. """The KL-divergence between two action distributions."""
  48. raise NotImplementedError
  49. @DeveloperAPI
  50. def entropy(self) -> TensorType:
  51. """The entropy of the action distribution."""
  52. raise NotImplementedError
  53. def multi_kl(self, other: "ActionDistribution") -> TensorType:
  54. """The KL-divergence between two action distributions.
  55. This differs from kl() in that it can return an array for
  56. MultiDiscrete. TODO(ekl) consider removing this.
  57. """
  58. return self.kl(other)
  59. def multi_entropy(self) -> TensorType:
  60. """The entropy of the action distribution.
  61. This differs from entropy() in that it can return an array for
  62. MultiDiscrete. TODO(ekl) consider removing this.
  63. """
  64. return self.entropy()
  65. @DeveloperAPI
  66. @staticmethod
  67. def required_model_output_shape(
  68. action_space: gym.Space,
  69. model_config: ModelConfigDict) -> Union[int, np.ndarray]:
  70. """Returns the required shape of an input parameter tensor for a
  71. particular action space and an optional dict of distribution-specific
  72. options.
  73. Args:
  74. action_space (gym.Space): The action space this distribution will
  75. be used for, whose shape attributes will be used to determine
  76. the required shape of the input parameter tensor.
  77. model_config (dict): Model's config dict (as defined in catalog.py)
  78. Returns:
  79. model_output_shape (int or np.ndarray of ints): size of the
  80. required input vector (minus leading batch dimension).
  81. """
  82. raise NotImplementedError