typing.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
  2. if TYPE_CHECKING:
  3. from ray.rllib.utils import try_import_tf, try_import_torch
  4. _, tf, _ = try_import_tf()
  5. torch, _ = try_import_torch()
  6. from ray.rllib.policy.policy import PolicySpec
  7. from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
  8. from ray.rllib.policy.view_requirement import ViewRequirement
  9. # Represents a fully filled out config of a Trainer class.
  10. # Note: Policy config dicts are usually the same as TrainerConfigDict, but
  11. # parts of it may sometimes be altered in e.g. a multi-agent setup,
  12. # where we have >1 Policies in the same Trainer.
  13. TrainerConfigDict = dict
  14. # A trainer config dict that only has overrides. It needs to be combined with
  15. # the default trainer config to be used.
  16. PartialTrainerConfigDict = dict
  17. # Represents the env_config sub-dict of the trainer config that is passed to
  18. # the env constructor.
  19. EnvConfigDict = dict
  20. # Represents the model config sub-dict of the trainer config that is passed to
  21. # the model catalog.
  22. ModelConfigDict = dict
  23. # Objects that can be created through the `from_config()` util method
  24. # need a config dict with a "type" key, a class path (str), or a type directly.
  25. FromConfigSpec = Union[Dict[str, Any], type, str]
  26. # Represents a BaseEnv, MultiAgentEnv, ExternalEnv, ExternalMultiAgentEnv,
  27. # VectorEnv, or gym.Env.
  28. EnvType = Any
  29. # Represents a generic identifier for an agent (e.g., "agent1").
  30. AgentID = Any
  31. # Represents a generic identifier for a policy (e.g., "pol1").
  32. PolicyID = str
  33. # Type of the config["multiagent"]["policies"] dict for multi-agent training.
  34. MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"]
  35. # Represents an environment id. These could be:
  36. # - An int index for a sub-env within a vectorized env.
  37. # - An external env ID (str), which changes(!) each episode.
  38. EnvID = Union[int, str]
  39. # Represents an episode id.
  40. EpisodeID = int
  41. # Represents an "unroll" (maybe across different sub-envs in a vector env).
  42. UnrollID = int
  43. # A dict keyed by agent ids, e.g. {"agent-1": value}.
  44. MultiAgentDict = Dict[AgentID, Any]
  45. # A dict keyed by env ids that contain further nested dictionaries keyed by
  46. # agent ids. e.g., {"env-1": {"agent-1": value}}.
  47. MultiEnvDict = Dict[EnvID, MultiAgentDict]
  48. # Represents an observation returned from the env.
  49. EnvObsType = Any
  50. # Represents an action passed to the env.
  51. EnvActionType = Any
  52. # Info dictionary returned by calling step() on gym envs. Commonly empty dict.
  53. EnvInfoDict = dict
  54. # Represents a File object
  55. FileType = Any
  56. # Represents a ViewRequirements dict mapping column names (str) to
  57. # ViewRequirement objects.
  58. ViewRequirementsDict = Dict[str, "ViewRequirement"]
  59. # Represents the result dict returned by Trainer.train().
  60. ResultDict = dict
  61. # A tf or torch local optimizer object.
  62. LocalOptimizer = Union["tf.keras.optimizers.Optimizer",
  63. "torch.optim.Optimizer"]
  64. # Dict of tensors returned by compute gradients on the policy, e.g.,
  65. # {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}, for multi-agent,
  66. # {"policy1": {"learner_stats": ..., }, "policy2": ...}.
  67. GradInfoDict = dict
  68. # Dict of learner stats returned by compute gradients on the policy, e.g.,
  69. # {"vf_loss": ..., ...}. This will always be nested under the "learner_stats"
  70. # key(s) of a GradInfoDict. In the multi-agent case, this will be keyed by
  71. # policy id.
  72. LearnerStatsDict = dict
  73. # Represents a generic tensor type.
  74. # This could be an np.ndarray, tf.Tensor, or a torch.Tensor.
  75. TensorType = Any
  76. # List of grads+var tuples (tf) or list of gradient tensors (torch)
  77. # representing model gradients and returned by compute_gradients().
  78. ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
  79. # Type of dict returned by get_weights() representing model weights.
  80. ModelWeights = dict
  81. # An input dict used for direct ModelV2 calls.
  82. ModelInputDict = Dict[str, TensorType]
  83. # Some kind of sample batch.
  84. SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
  85. # Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
  86. TensorStructType = Union[TensorType, dict, tuple]
  87. # A shape of a tensor.
  88. TensorShape = Union[Tuple[int], List[int]]