# replay_buffer_demo.py
# Demonstration of RLlib's ReplayBuffer workflow
from typing import Optional
import random

import numpy as np

from ray import air, tune
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.algorithms.dqn.dqn import DQNConfig


# __sphinx_doc_replay_buffer_type_specification__begin__
config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer})

another_config = DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"})

yet_another_config = DQNConfig().training(
    replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
)

validate_buffer_config(config)
validate_buffer_config(another_config)
validate_buffer_config(yet_another_config)

# After validation, all three configs yield the same effective config.
assert (
    config.replay_buffer_config
    == another_config.replay_buffer_config
    == yet_another_config.replay_buffer_config
)
# __sphinx_doc_replay_buffer_type_specification__end__
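
# Illustrative follow-up (not part of the doc snippet above, just a sketch):
# after validation, replay_buffer_config is a plain dict, so printing it is an
# easy way to inspect the resolved buffer type and any defaults that were filled in.
print(config.replay_buffer_config)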

# __sphinx_doc_replay_buffer_basic_interaction__begin__
# We choose fragments because they impose no restrictions on the batches we add.
buffer = ReplayBuffer(capacity=2, storage_unit=StorageUnit.FRAGMENTS)
dummy_batch = SampleBatch({"a": [1], "b": [2]})
buffer.add(dummy_batch)
buffer.sample(2)
# Because elements can be sampled multiple times, we receive a concatenated version
# of dummy_batch: `{"a": [1, 1], "b": [2, 2]}`.
# __sphinx_doc_replay_buffer_basic_interaction__end__
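
# Illustrative follow-up (not part of the doc snippet above): capture the sampled
# batch to look at its contents directly. With a single stored fragment, sampling
# two items should return that fragment twice, concatenated.
sampled = buffer.sample(2)
print(sampled["a"], sampled["b"])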

# __sphinx_doc_replay_buffer_own_buffer__begin__
class LessSampledReplayBuffer(ReplayBuffer):
    @override(ReplayBuffer)
    def sample(
        self, num_items: int, evict_sampled_more_then: int = 30, **kwargs
    ) -> Optional[SampleBatchType]:
        """Evicts experiences that have been sampled > evict_sampled_more_then times."""
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        often_sampled_idxes = list(
            filter(lambda x: self._hit_count[x] >= evict_sampled_more_then, set(idxes))
        )

        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count

        # Evict items that have been sampled too often, together with their hit counts.
        for idx in often_sampled_idxes:
            del self._storage[idx]
            self._hit_count = np.append(
                self._hit_count[:idx], self._hit_count[idx + 1 :]
            )

        return sample


config = (
    DQNConfig()
    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
    .environment(env="CartPole-v1")
)

tune.Tuner(
    "DQN",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={"training_iteration": 1},
    ),
).fit()
# __sphinx_doc_replay_buffer_own_buffer__end__
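
# Minimal standalone sketch (illustrative only; `demo_buffer` is not part of the
# original example): exercise LessSampledReplayBuffer directly with FRAGMENTS
# storage. With evict_sampled_more_then=2, the single stored fragment is evicted
# on the third sample call, once its hit count has reached the threshold.
demo_buffer = LessSampledReplayBuffer(capacity=2, storage_unit=StorageUnit.FRAGMENTS)
demo_buffer.add(SampleBatch({"a": [1]}))
for _ in range(3):
    demo_buffer.sample(num_items=1, evict_sampled_more_then=2)
assert len(demo_buffer._storage) == 0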

# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__begin__
# This setting makes our buffer store only complete episodes found in an added batch.
config.training(replay_buffer_config={"storage_unit": StorageUnit.EPISODES})

less_sampled_buffer = LessSampledReplayBuffer(**config.replay_buffer_config)

# Gather some random experiences.
env = RandomEnv()
terminated = truncated = False
batch = SampleBatch({})
t = 0
while not terminated and not truncated:
    obs, reward, terminated, truncated, info = env.step([0, 0])
    # For RLlib to recognize an episode's start and end, the "t" and "terminateds"
    # columns have to properly mark the episode's trajectory.
    one_step_batch = SampleBatch(
        {
            "obs": [obs],
            "t": [t],
            "reward": [reward],
            "terminateds": [terminated],
            "truncateds": [truncated],
        }
    )
    batch = concat_samples([batch, one_step_batch])
    t += 1

less_sampled_buffer.add(batch)

# Sample until the episode has been drawn more than `evict_sampled_more_then` times,
# at which point it gets evicted from the buffer.
for i in range(10):
    assert len(less_sampled_buffer._storage) == 1
    less_sampled_buffer.sample(num_items=1, evict_sampled_more_then=9)

assert len(less_sampled_buffer._storage) == 0
# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__end__
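
# Illustrative check (a sketch, not part of the original example): with EPISODES
# storage, the whole trajectory is kept as a single item, so re-adding the episode
# stores exactly one entry spanning all collected steps.
less_sampled_buffer.add(batch)
assert len(less_sampled_buffer._storage) == 1
print(less_sampled_buffer._storage[0].count, "timesteps stored as one episode")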

# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
config = (
    DQNConfig()
    .training(
        replay_buffer_config={
            "type": "MultiAgentReplayBuffer",
            "underlying_replay_buffer_config": {
                "type": LessSampledReplayBuffer,
                # Here we can set the default call argument for the
                # underlying buffer's sample() method.
                "evict_sampled_more_then": 20,
            },
        }
    )
    .environment(env="CartPole-v1")
)

tune.Tuner(
    "DQN",
    param_space=config.to_dict(),
    run_config=air.RunConfig(stop={"episode_reward_mean": 40, "training_iteration": 7}),
).fit()
# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__end__