bandit.py

import logging
from typing import Optional, Type, Union

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bandit.bandit_tf_policy import BanditTFPolicy
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, ALGO_DEPRECATION_WARNING

logger = logging.getLogger(__name__)


class BanditConfig(AlgorithmConfig):
    """Defines a contextual bandit configuration class from which
    a contextual bandit algorithm can be built. Note that this config is
    shared between BanditLinUCB and BanditLinTS. You likely want to use the
    child classes BanditLinTSConfig or BanditLinUCBConfig instead.
    """

    def __init__(
        self,
        algo_class: Optional[Union[Type["BanditLinTS"], Type["BanditLinUCB"]]] = None,
    ):
        super().__init__(algo_class=algo_class)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.framework_str = "torch"
        self.rollout_fragment_length = 1
        self.train_batch_size = 1
        # Make sure a `train()` call performs at least 100 env sampling
        # timesteps before reporting results. Not setting this (default is 0)
        # would significantly slow down the Bandit Algorithm.
        self.min_sample_timesteps_per_iteration = 100
        # __sphinx_doc_end__
        # fmt: on
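

# A minimal usage sketch of the shared bandit defaults above (illustrative,
# not part of the module's API; `WheelBanditEnv` is RLlib's example env, as
# used in the doctests below):
#
#   from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
#
#   config = BanditLinTSConfig()
#   assert config.train_batch_size == 1  # bandits update after every step
#   algo = config.build(env=WheelBanditEnv)
#   result = algo.train()  # reports only after >= 100 sampled timesteps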


class BanditLinTSConfig(BanditConfig):
    """Defines a configuration class from which a Thompson-sampling bandit can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinTSConfig  # doctest: +SKIP
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv  # doctest: +SKIP
        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)  # doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=WheelBanditEnv)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP
    """

    def __init__(self):
        super().__init__(algo_class=BanditLinTS)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.exploration_config = {"type": "ThompsonSampling"}
        # __sphinx_doc_end__
        # fmt: on
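

# A minimal sketch (illustrative, not part of the API) of what this subclass
# changes relative to the shared BanditConfig defaults:
#
#   config = BanditLinTSConfig()
#   print(config.exploration_config)  # -> {"type": "ThompsonSampling"}
#   print(config.framework_str)       # -> "torch" (inherited bandit default)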


class BanditLinUCBConfig(BanditConfig):
    """Defines a config class from which an upper confidence bound bandit can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinUCBConfig  # doctest: +SKIP
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv  # doctest: +SKIP
        >>> config = BanditLinUCBConfig()  # doctest: +SKIP
        >>> config = config.rollouts(num_rollout_workers=4)  # doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=WheelBanditEnv)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP
    """

    def __init__(self):
        super().__init__(algo_class=BanditLinUCB)
        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with bandit-specific values.
        self.exploration_config = {"type": "UpperConfidenceBound"}
        # __sphinx_doc_end__
        # fmt: on
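

# The two child configs differ only in their `exploration_config`; a minimal,
# illustrative comparison (not part of the module):
#
#   ts, ucb = BanditLinTSConfig(), BanditLinUCBConfig()
#   assert ts.exploration_config["type"] == "ThompsonSampling"
#   assert ucb.exploration_config["type"] == "UpperConfidenceBound"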


@Deprecated(
    old="rllib/algorithms/bandit/",
    new="rllib_contrib/bandit/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class BanditLinTS(Algorithm):
    """Bandit Algorithm using ThompsonSampling exploration."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> BanditLinTSConfig:
        return BanditLinTSConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
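

# The default policy class is chosen from the config's `framework` setting; a
# minimal sketch (WheelBanditEnv as in the docstrings above):
#
#   config = BanditLinTSConfig().framework("tf2")
#   algo = config.build(env=WheelBanditEnv)  # -> uses BanditTFPolicy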


@Deprecated(
    old="rllib/algorithms/bandit/",
    new="rllib_contrib/bandit/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class BanditLinUCB(Algorithm):
    """Bandit Algorithm using UpperConfidenceBound exploration."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> BanditLinUCBConfig:
        return BanditLinUCBConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
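

# A minimal end-to-end sketch, runnable as a script (an illustration, not part
# of the original module; assumes Ray RLlib and its bandit example envs are
# installed):
if __name__ == "__main__":
    from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv

    # Build the UCB bandit on the example env and run a few iterations.
    algo = BanditLinUCBConfig().build(env=WheelBanditEnv)
    for _ in range(3):
        result = algo.train()
        print(result["episode_reward_mean"])
    algo.stop()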