# test_execution.py

import numpy as np
import time
import gym
import queue

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.rollout_ops import ParallelRollouts, AsyncGradients, \
    ConcatBatches, StandardizeFields
from ray.rllib.execution.train_ops import TrainOneStep, ComputeGradients, \
    AverageGradients
from ray.rllib.execution.replay_buffer import LocalReplayBuffer, \
    ReplayActor
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.util.iter import LocalIterator, from_range
from ray.util.iter_metrics import SharedMetrics


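# NOTE: the `ray_start_regular_shared` fixture used throughout is assumed to
# come from Ray's shared pytest conftest; it provides a Ray instance that is
# reused across the tests in this module.

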
def iter_list(values):
    """Wrap a plain Python list in a LocalIterator for testing."""
    return LocalIterator(lambda _: values, SharedMetrics())


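# Helper: builds a WorkerSet with one local RolloutWorker plus `n` remote
# workers, all running PPOTFPolicy on CartPole-v0 with 100-step rollout
# fragments.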
def make_workers(n):
    local = RolloutWorker(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_spec=PPOTFPolicy,
        rollout_fragment_length=100)
    remotes = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=PPOTFPolicy,
            rollout_fragment_length=100) for _ in range(n)
    ]
    workers = WorkerSet._from_existing(local, remotes)
    return workers


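# `Concurrently` merges several iterators into one. "round_robin" pulls from
# each input in turn; "async" yields items as they become ready, which for
# purely local iterators produces the same interleaving here.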
def test_concurrently(ray_start_regular_shared):
    a = iter_list([1, 2, 3])
    b = iter_list([4, 5, 6])
    c = Concurrently([a, b], mode="round_robin")
    assert c.take(6) == [1, 4, 2, 5, 3, 6]

    a = iter_list([1, 2, 3])
    b = iter_list([4, 5, 6])
    c = Concurrently([a, b], mode="async")
    assert c.take(6) == [1, 4, 2, 5, 3, 6]


def test_concurrently_weighted(ray_start_regular_shared):
    """Round-robin mode draws from each input according to its weight."""
    a = iter_list([1, 1, 1])
    b = iter_list([2, 2, 2])
    c = iter_list([3, 3, 3])
    c = Concurrently(
        [a, b, c], mode="round_robin", round_robin_weights=[3, 1, 2])
    assert c.take(9) == [1, 1, 1, 2, 3, 3, 2, 3, 2]

    a = iter_list([1, 1, 1])
    b = iter_list([2, 2, 2])
    c = iter_list([3, 3, 3])
    c = Concurrently(
        [a, b, c], mode="round_robin", round_robin_weights=[1, 1, "*"])
    assert c.take(9) == [1, 2, 3, 3, 3, 1, 2, 1, 2]


def test_concurrently_output(ray_start_regular_shared):
    """output_indexes restricts the output to items from the selected inputs."""
    a = iter_list([1, 2, 3])
    b = iter_list([4, 5, 6])
    c = Concurrently([a, b], mode="round_robin", output_indexes=[1])
    assert c.take(6) == [4, 5, 6]

    a = iter_list([1, 2, 3])
    b = iter_list([4, 5, 6])
    c = Concurrently([a, b], mode="round_robin", output_indexes=[0, 1])
    assert c.take(6) == [1, 4, 2, 5, 3, 6]


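# `Enqueue` pushes each item into a queue.Queue as a side effect; `Dequeue`
# wraps a queue as an iterator, bridging operator pipelines and plain queues.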
def test_enqueue_dequeue(ray_start_regular_shared):
    a = iter_list([1, 2, 3])
    q = queue.Queue(100)
    a.for_each(Enqueue(q)).take(3)
    assert q.qsize() == 3
    assert q.get_nowait() == 1
    assert q.get_nowait() == 2
    assert q.get_nowait() == 3

    q.put("a")
    q.put("b")
    q.put("c")
    a = Dequeue(q)
    assert a.take(3) == ["a", "b", "c"]


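# `StandardMetricsReporting` wraps an iterator so that results are emitted at
# the configured reporting interval (min_iter_time_s), attaching episode
# metrics collected from the workers.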
def test_metrics(ray_start_regular_shared):
    workers = make_workers(1)
    workers.foreach_worker(lambda w: w.sample())
    a = from_range(10, repeat=True).gather_sync()
    b = StandardMetricsReporting(
        a, workers, {
            "min_iter_time_s": 2.5,
            "timesteps_per_iteration": 0,
            "metrics_smoothing_episodes": 10,
            "collect_metrics_timeout": 10,
        })

    start = time.time()
    res1 = next(b)
    assert res1["episode_reward_mean"] > 0, res1
    res2 = next(b)
    assert res2["episode_reward_mean"] > 0, res2
    assert time.time() - start > 2.4
    workers.stop()


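# `ParallelRollouts` turns a WorkerSet into an iterator over sample batches.
# "bulk_sync" concatenates the batches from all workers (2 workers x 100
# steps = 200); "async" yields each worker's 100-step batch individually.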
def test_rollouts(ray_start_regular_shared):
    workers = make_workers(2)
    a = ParallelRollouts(workers, mode="bulk_sync")
    assert next(a).count == 200
    counters = a.shared_metrics.get().counters
    assert counters[STEPS_SAMPLED_COUNTER] == 200, counters

    a = ParallelRollouts(workers, mode="async")
    assert next(a).count == 100
    counters = a.shared_metrics.get().counters
    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
    workers.stop()


def test_rollouts_local(ray_start_regular_shared):
    """With no remote workers, rollouts come from the local worker."""
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    assert next(a).count == 100
    counters = a.shared_metrics.get().counters
    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
    workers.stop()


def test_concat_batches(ray_start_regular_shared):
    """ConcatBatches coalesces batches until the minimum size is reached."""
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="async")
    b = a.combine(ConcatBatches(1000))
    assert next(b).count == 1000
    timers = b.shared_metrics.get().timers
    assert "sample" in timers


def test_standardize(ray_start_regular_shared):
    """StandardizeFields normalizes a column to zero mean and unit std
    (eps_id is used here just as a convenient numeric column)."""
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="async")
    b = a.for_each(StandardizeFields([SampleBatch.EPS_ID]))
    batch = next(b)
    assert abs(np.mean(batch[SampleBatch.EPS_ID])) < 0.001, batch
    assert abs(np.std(batch[SampleBatch.EPS_ID]) - 1.0) < 0.001, batch


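# `AsyncGradients` computes gradients on the remote workers as rollouts
# complete, yielding (gradient, batch_count) tuples.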
def test_async_grads(ray_start_regular_shared):
    workers = make_workers(2)
    a = AsyncGradients(workers)
    res1 = next(a)
    assert isinstance(res1, tuple) and len(res1) == 2, res1
    counters = a.shared_metrics.get().counters
    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
    workers.stop()


def test_train_one_step(ray_start_regular_shared):
    """TrainOneStep applies each batch to the policy, emitting (batch, stats)."""
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(TrainOneStep(workers))
    batch, stats = next(b)
    assert isinstance(batch, SampleBatch)
    assert DEFAULT_POLICY_ID in stats
    assert "learner_stats" in stats[DEFAULT_POLICY_ID]
    counters = a.shared_metrics.get().counters
    assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
    assert counters[STEPS_TRAINED_COUNTER] == 100, counters
    timers = a.shared_metrics.get().timers
    assert "learn" in timers
    workers.stop()


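# `ComputeGradients` yields (gradients, count) pairs without applying them;
# `AverageGradients` then averages a batch of such gradients into one update.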
def test_compute_gradients(ray_start_regular_shared):
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(ComputeGradients(workers))
    grads, counts = next(b)
    assert counts == 100, counts
    timers = a.shared_metrics.get().timers
    assert "compute_grads" in timers


def test_avg_gradients(ray_start_regular_shared):
    """Averaging a batch of 4 gradients covers 4 x 100 = 400 samples."""
    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(ComputeGradients(workers)).batch(4)
    c = b.for_each(AverageGradients())
    grads, counts = next(c)
    assert counts == 400, counts


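# `StoreToReplayBuffer` writes batches into a replay buffer as a side effect;
# `Replay` reads from it. The buffer returns no samples until
# `learning_starts` (200 here) timesteps have been added.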
def test_store_to_replay_local(ray_start_regular_shared):
    buf = LocalReplayBuffer(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001)
    assert buf.replay() is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

    next(b)
    assert buf.replay() is None  # learning hasn't started yet
    next(b)
    assert buf.replay().count == 100

    replay_op = Replay(local_buffer=buf)
    assert next(replay_op).count == 100


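# Same flow as above, but with the buffer hosted in a remote `ReplayActor`,
# so all reads and writes go through remote calls and `ray.get()`.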
def test_store_to_replay_actor(ray_start_regular_shared):
    actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        buffer_size=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001)
    assert ray.get(actor.replay.remote()) is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(actors=[actor]))

    next(b)
    assert ray.get(actor.replay.remote()) is None  # learning hasn't started
    next(b)
    assert ray.get(actor.replay.remote()).count == 100

    replay_op = Replay(actors=[actor])
    assert next(replay_op).count == 100


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))