test_filters.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import numpy as np
  2. import unittest
  3. import ray
  4. from ray.rllib.utils.filter import RunningStat, MeanStdFilter
  5. from ray.rllib.utils import FilterManager
  6. from ray.rllib.tests.mock_worker import _MockWorker
  class RunningStatTest(unittest.TestCase):
      """Checks ``RunningStat`` against NumPy reference statistics."""

      def testRunningStat(self):
          """Incremental pushes track NumPy's mean and sample variance.

          After every push the running mean must match ``np.mean`` over all
          values seen so far, and the running variance must match the
          unbiased (``ddof=1``) ``np.var``. With a single sample the test
          expects ``np.square(mean)`` as the variance.
          """
          for shp in ((), (3, ), (3, 4)):
              li = []
              rs = RunningStat(shp)
              for _ in range(5):
                  val = np.random.randn(*shp)
                  rs.push(val)
                  li.append(val)
                  m = np.mean(li, axis=0)
                  self.assertTrue(np.allclose(rs.mean, m))
                  v = (np.square(m)
                       if (len(li) == 1) else np.var(li, ddof=1, axis=0))
                  self.assertTrue(np.allclose(rs.var, v))

      def testCombiningStat(self):
          """Merging two stats with ``update`` equals pushing all data into one.

          ``rs1`` and ``rs2`` each see a disjoint part of the data, while
          ``rs`` sees all of it. After ``rs1.update(rs2)`` the merged
          statistics must agree with ``rs`` and with NumPy over the full
          sample.
          """
          for shape in [(), (3, ), (3, 4)]:
              li = []
              rs1 = RunningStat(shape)
              rs2 = RunningStat(shape)
              rs = RunningStat(shape)
              for _ in range(5):
                  val = np.random.randn(*shape)
                  rs1.push(val)
                  rs.push(val)
                  li.append(val)
              for _ in range(9):
                  # Draw a fresh sample. Reusing the last ``val`` would feed
                  # rs2 nine identical values (zero spread), so the merge
                  # would never combine two non-trivial variances.
                  val = np.random.randn(*shape)
                  rs2.push(val)
                  rs.push(val)
                  li.append(val)
              rs1.update(rs2)
              self.assertTrue(np.allclose(rs.mean, rs1.mean))
              self.assertTrue(np.allclose(rs.std, rs1.std))
              # Cross-check the merged result against NumPy over all samples.
              self.assertTrue(np.allclose(rs1.mean, np.mean(li, axis=0)))
              self.assertTrue(
                  np.allclose(rs1.var, np.var(li, ddof=1, axis=0)))
  class MSFTest(unittest.TestCase):
      """Exercises ``MeanStdFilter`` bookkeeping: sync, buffers, merging."""

      def testBasic(self):
          """Check sample counts through sync, clear and apply_changes.

          For each shape, five all-ones observations go through a filter.
          That state is copied into a second filter via ``sync``. The first
          filter's buffer is then cleared, and the second filter's changes
          are merged back, first without and then with its buffer.
          """
          shapes = [(), (3, ), (3, 4, 4)]
          for shape in shapes:
              source = MeanStdFilter(shape)
              for _ in range(5):
                  source(np.ones(shape))
              self.assertEqual(source.rs.n, 5)
              self.assertEqual(source.buffer.n, 5)

              # A freshly synced filter mirrors both running stats and buffer.
              replica = MeanStdFilter(shape)
              replica.sync(source)
              self.assertEqual(replica.rs.n, 5)
              self.assertEqual(replica.buffer.n, 5)

              # Clearing one filter's buffer leaves the replica untouched.
              source.clear_buffer()
              self.assertEqual(source.buffer.n, 0)
              self.assertEqual(replica.buffer.n, 5)

              # Merging without the buffer grows rs only.
              source.apply_changes(replica, with_buffer=False)
              self.assertEqual(source.buffer.n, 0)
              self.assertEqual(source.rs.n, 10)

              # Merging with the buffer grows both rs and buffer.
              source.apply_changes(replica, with_buffer=True)
              self.assertEqual(source.buffer.n, 5)
              self.assertEqual(source.rs.n, 15)
  class FilterManagerTest(unittest.TestCase):
      """Tests ``FilterManager.synchronize`` against a remote mock worker."""

      def setUp(self):
          # Small single-CPU Ray instance; tolerate an already-running one.
          ray.init(
              num_cpus=1,
              object_store_memory=1000 * 1024 * 1024,
              ignore_reinit_error=True)

      def tearDown(self):
          ray.shutdown()

      def test_synchronize(self):
          """Synchronize folds the worker's buffer into the local filter.

          Afterwards the local filter and the worker's ``obs_filter`` report
          the same running-stat and buffer counts.
          """
          local_filter = MeanStdFilter(())
          for obs in range(10):
              local_filter(obs)
          self.assertEqual(local_filter.rs.n, 10)
          local_filter.clear_buffer()
          self.assertEqual(local_filter.buffer.n, 0)

          worker_cls = ray.remote(_MockWorker)
          worker = worker_cls.remote(sample_count=10)
          worker.sample.remote()

          local_filters = {
              "obs_filter": local_filter,
              "rew_filter": local_filter.copy()
          }
          FilterManager.synchronize(local_filters, [worker])

          remote_filters = ray.get(worker.get_filters.remote())
          remote_obs_filter = remote_filters["obs_filter"]

          self.assertEqual(local_filter.rs.n, 20)
          self.assertEqual(local_filter.buffer.n, 0)
          self.assertEqual(remote_obs_filter.rs.n, local_filter.rs.n)
          self.assertEqual(remote_obs_filter.buffer.n, local_filter.buffer.n)
  if __name__ == "__main__":
      # Run this file's tests verbosely under pytest and propagate its
      # exit status to the shell.
      import sys

      import pytest

      exit_code = pytest.main(["-v", __file__])
      sys.exit(exit_code)