# test_minibatch_utils.py

import unittest

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.minibatch_utils import MiniBatchCyclicIterator

CONFIGS = [
    {"mini_batch_size": 128, "num_sgd_iter": 3, "agent_steps": (56, 56)},
    {"mini_batch_size": 128, "num_sgd_iter": 7, "agent_steps": (56, 56)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 56)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 3)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 4)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (56, 55)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (400, 400)},
    {"mini_batch_size": 128, "num_sgd_iter": 10, "agent_steps": (64, 64)},
]
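
# Rough sketch of what the test below asserts, using CONFIGS[0] as a worked
# example (mini_batch_size=128, num_sgd_iter=3, agent_steps=(56, 56)): the
# iterator should yield minibatches containing exactly 128 timesteps per
# policy, wrapping around each policy's 56-step batch as needed, and should
# stop after ceil(3 * 56 / 128) = 2 minibatches. This only summarizes the
# assertions made here, not the full MiniBatchCyclicIterator API.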


class TestMinibatchUtils(unittest.TestCase):
    def test_minibatch_cyclic_iterator(self):
        for config in CONFIGS:
            mini_batch_size = config["mini_batch_size"]
            num_sgd_iter = config["num_sgd_iter"]
            agent_steps = config["agent_steps"]
            num_env_steps = max(agent_steps)
            # One SampleBatch per policy; obs is [0, 1, ..., n - 1] so that the
            # wrap-around checks below can predict the last item exactly.
            sample_batches = {
                f"pol{i}": SampleBatch({"obs": np.arange(agent_steps[i])})
                for i in range(len(agent_steps))
            }
            mb = MultiAgentBatch(sample_batches, num_env_steps)
            batch_iter = MiniBatchCyclicIterator(mb, mini_batch_size, num_sgd_iter)
            print(config)
            iteration_counter = 0
            for batch in batch_iter:
                print(batch)
                print("-" * 80)
                print(batch["pol0"]["obs"])
                print("*" * 80)
                # Check that for each policy the batch size is equal to the
                # mini_batch_size.
                for policy_batch in batch.policy_batches.values():
                    self.assertEqual(policy_batch.count, mini_batch_size)
                iteration_counter += 1
                # For each policy, check that the last item in the batch matches
                # the expected value, i.e.
                # (iteration_counter * mini_batch_size - 1) % agent_steps.
                total_steps = iteration_counter * mini_batch_size
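                # Worked example for CONFIGS[0] (mini_batch_size=128,
                # agent_steps=(56, 56)): after the first minibatch,
                # total_steps == 128, so the last obs value should be
                # (128 - 1) % 56 == 15; after the second, (256 - 1) % 56 == 31.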
                for policy_idx, policy_batch in enumerate(
                    batch.policy_batches.values()
                ):
                    expected_last_item = (total_steps - 1) % agent_steps[policy_idx]
                    self.assertEqual(policy_batch["obs"][-1], expected_last_item)
            # Check the iteration counter (should be
            # ceil(num_sgd_iter * max(agent_steps) / mini_batch_size)).
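            # E.g. for CONFIGS[0]: ceil(3 * 56 / 128) = 2 minibatches;
            # for CONFIGS[6]: ceil(10 * 400 / 128) = 32.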
            expected_iteration_counter = np.ceil(
                num_sgd_iter * max(agent_steps) / mini_batch_size
            )
            self.assertEqual(iteration_counter, expected_iteration_counter)
            print(f"iteration_counter: {iteration_counter}")


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))