import gymnasium as gym
import logging
import time
from typing import Dict

from gymnasium import spaces

import ray
from ray.rllib.algorithms.bandit.bandit_tf_model import (
    DiscreteLinearModelThompsonSampling,
    DiscreteLinearModelUCB,
    DiscreteLinearModel,
    ParametricLinearModelThompsonSampling,
    ParametricLinearModelUCB,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.tf_utils import make_tf_callable
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
from ray.util.debug import log_once

logger = logging.getLogger(__name__)


class BanditPolicyOverrides:
    def __init__(self):
        @make_tf_callable(self.get_session())
        def learn_on_batch(postprocessed_batch) -> Dict[str, TensorType]:
            # INFOS dict can't be converted to Tensor with the interceptor.
            postprocessed_batch.set_get_interceptor(None)

            unflattened_obs = restore_original_dimensions(
                postprocessed_batch[SampleBatch.CUR_OBS],
                self.observation_space,
                self.framework,
            )

            info = {}

            start = time.time()
            self.model.partial_fit(
                unflattened_obs,
                postprocessed_batch[SampleBatch.REWARDS],
                postprocessed_batch[SampleBatch.ACTIONS],
            )

            infos = postprocessed_batch[SampleBatch.INFOS]
            if "regret" in infos[0]:
                regret = sum(
                    row["infos"]["regret"] for row in postprocessed_batch.rows()
                )
                self.regrets.append(regret)
                info["cumulative_regret"] = sum(self.regrets)
            else:
                if log_once("no_regrets"):
                    logger.warning(
                        "The env did not report `regret` values in "
                        "its `info` return, ignoring."
                    )

            info["update_latency"] = time.time() - start
            return {LEARNER_STATS_KEY: info}

        self.learn_on_batch = learn_on_batch


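# Bandit policies do not optimize a differentiable loss: `learn_on_batch`
# above performs an incremental `partial_fit` update of the linear model on
# every batch, which is why `build_tf_policy` at the bottom of this file
# passes `loss_fn=None`.

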
def validate_spaces(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Validates the observation and action spaces used by the Policy.

    Args:
        policy: The policy whose spaces are being validated.
        observation_space: The observation space to validate.
        action_space: The action space to validate.
        config: The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # Bandit policies only support single `Discrete` action spaces.
    if not isinstance(action_space, gym.spaces.Discrete):
        msg = (
            f"Action space ({action_space}) of {policy} is not supported for "
            f"Bandit algorithms. Must be `Discrete`."
        )
        # Hint at using the MultiDiscrete-to-Discrete wrapper for Bandits.
        if isinstance(action_space, gym.spaces.MultiDiscrete):
            msg += (
                " Try to wrap your environment with the "
                "`ray.rllib.env.wrappers.recsim::"
                "MultiDiscreteToDiscreteActionWrapper` class: `tune.register_env("
                "[some str], lambda ctx: MultiDiscreteToDiscreteActionWrapper("
                "[your gym env])); config = {'env': [some str]}`"
            )
        raise UnsupportedSpaceException(msg)


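# A minimal sketch of the wrapper hint in the error message above ("my-env"
# and `MyMultiDiscreteEnv` are hypothetical placeholders):
#
#     from ray import tune
#     from ray.rllib.env.wrappers.recsim import (
#         MultiDiscreteToDiscreteActionWrapper,
#     )
#
#     tune.register_env(
#         "my-env",
#         lambda ctx: MultiDiscreteToDiscreteActionWrapper(MyMultiDiscreteEnv()),
#     )
#     config = {"env": "my-env"}

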
def make_model(policy, obs_space, action_space, config):
    _, logit_dim = ModelCatalog.get_action_dist(
        action_space, config["model"], framework="tf"
    )

    model_cls = DiscreteLinearModel

    if hasattr(obs_space, "original_space"):
        original_space = obs_space.original_space
    else:
        original_space = obs_space

    exploration_config = config.get("exploration_config")
    # The model class depends on the exploration strategy, because exploration
    # is implicit in how each model scores actions.
    # TODO: Have a separate model catalogue for bandits.
    if exploration_config:
        if exploration_config["type"] == "ThompsonSampling":
            if isinstance(original_space, spaces.Dict):
                assert (
                    "item" in original_space.spaces
                ), "Cannot find 'item' key in observation space"
                model_cls = ParametricLinearModelThompsonSampling
            else:
                model_cls = DiscreteLinearModelThompsonSampling
        elif exploration_config["type"] == "UpperConfidenceBound":
            if isinstance(original_space, spaces.Dict):
                assert (
                    "item" in original_space.spaces
                ), "Cannot find 'item' key in observation space"
                model_cls = ParametricLinearModelUCB
            else:
                model_cls = DiscreteLinearModelUCB

    model = model_cls(
        obs_space, action_space, logit_dim, config["model"], name="LinearModel"
    )
    return model


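# For example, a config with
#
#     config["exploration_config"] = {"type": "ThompsonSampling"}
#
# selects `ParametricLinearModelThompsonSampling` when the (original)
# observation space is a `spaces.Dict` containing an "item" key, and
# `DiscreteLinearModelThompsonSampling` otherwise. Any other (or no)
# exploration type falls back to the plain `DiscreteLinearModel`.

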
def after_init(policy, *args):
    policy.regrets = []
    BanditPolicyOverrides.__init__(policy)


BanditTFPolicy = build_tf_policy(
    name="BanditTFPolicy",
    get_default_config=lambda: ray.rllib.algorithms.bandit.bandit.BanditConfig(),
    validate_spaces=validate_spaces,
    make_model=make_model,
    loss_fn=None,
    mixins=[BanditPolicyOverrides],
    after_init=after_init,
)
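

# Usage sketch: `BanditTFPolicy` is normally built through a bandit Algorithm
# config rather than instantiated directly (assuming Ray 2.x, where
# `BanditLinUCBConfig` and `BanditLinTSConfig` are exported from
# `ray.rllib.algorithms.bandit`; "my-bandit-env" is a hypothetical registered
# env with a `Discrete` action space):
#
#     from ray.rllib.algorithms.bandit import BanditLinUCBConfig
#
#     config = BanditLinUCBConfig().environment("my-bandit-env").framework("tf2")
#     algo = config.build()
#     results = algo.train()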