test_reproducibility.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import gymnasium as gym
  2. import numpy as np
  3. import unittest
  4. import ray
  5. from ray.rllib.algorithms.dqn import DQNConfig
  6. from ray.rllib.utils.test_utils import framework_iterator
  7. from ray.tune.registry import register_env
  8. class TestReproducibility(unittest.TestCase):
  9. def test_reproducing_trajectory(self):
  10. class PickLargest(gym.Env):
  11. def __init__(self):
  12. self.observation_space = gym.spaces.Box(
  13. low=float("-inf"), high=float("inf"), shape=(4,)
  14. )
  15. self.action_space = gym.spaces.Discrete(4)
  16. def reset(self, *, seed=None, options=None):
  17. self.obs = np.random.randn(4)
  18. return self.obs, {}
  19. def step(self, action):
  20. reward = self.obs[action]
  21. return self.obs, reward, True, False, {}
  22. def env_creator(env_config):
  23. return PickLargest()
  24. for fw in framework_iterator(frameworks=("tf", "torch")):
  25. trajs = list()
  26. for trial in range(3):
  27. ray.init()
  28. register_env("PickLargest", env_creator)
  29. config = (
  30. DQNConfig()
  31. .environment("PickLargest")
  32. .debugging(seed=666 if trial in [0, 1] else 999)
  33. .reporting(
  34. min_time_s_per_iteration=0,
  35. min_sample_timesteps_per_iteration=100,
  36. )
  37. .framework(fw)
  38. )
  39. algo = config.build()
  40. trajectory = list()
  41. for _ in range(8):
  42. r = algo.train()
  43. trajectory.append(r["episode_reward_max"])
  44. trajectory.append(r["episode_reward_min"])
  45. trajs.append(trajectory)
  46. algo.stop()
  47. ray.shutdown()
  48. # trial0 and trial1 use same seed and thus
  49. # expect identical trajectories.
  50. all_same = True
  51. for v0, v1 in zip(trajs[0], trajs[1]):
  52. if v0 != v1:
  53. all_same = False
  54. self.assertTrue(all_same)
  55. # trial1 and trial2 use different seeds and thus
  56. # most rewards tend to be different.
  57. diff_cnt = 0
  58. for v1, v2 in zip(trajs[1], trajs[2]):
  59. if v1 != v2:
  60. diff_cnt += 1
  61. self.assertTrue(diff_cnt > 8)
  62. if __name__ == "__main__":
  63. import pytest
  64. import sys
  65. sys.exit(pytest.main(["-v", __file__]))