test_reproducibility.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import gym
  2. import numpy as np
  3. import unittest
  4. import ray
  5. from ray.rllib.agents.dqn import DQNTrainer
  6. from ray.rllib.utils.test_utils import framework_iterator
  7. from ray.tune.registry import register_env
  8. class TestReproducibility(unittest.TestCase):
  9. def test_reproducing_trajectory(self):
  10. class PickLargest(gym.Env):
  11. def __init__(self):
  12. self.observation_space = gym.spaces.Box(
  13. low=float("-inf"), high=float("inf"), shape=(4, ))
  14. self.action_space = gym.spaces.Discrete(4)
  15. def reset(self, **kwargs):
  16. self.obs = np.random.randn(4)
  17. return self.obs
  18. def step(self, action):
  19. reward = self.obs[action]
  20. return self.obs, reward, True, {}
  21. def env_creator(env_config):
  22. return PickLargest()
  23. for fw in framework_iterator(frameworks=("tf", "torch")):
  24. trajs = list()
  25. for trial in range(3):
  26. ray.init()
  27. register_env("PickLargest", env_creator)
  28. config = {
  29. "seed": 666 if trial in [0, 1] else 999,
  30. "min_iter_time_s": 0,
  31. "timesteps_per_iteration": 100,
  32. "framework": fw,
  33. }
  34. agent = DQNTrainer(config=config, env="PickLargest")
  35. trajectory = list()
  36. for _ in range(8):
  37. r = agent.train()
  38. trajectory.append(r["episode_reward_max"])
  39. trajectory.append(r["episode_reward_min"])
  40. trajs.append(trajectory)
  41. ray.shutdown()
  42. # trial0 and trial1 use same seed and thus
  43. # expect identical trajectories.
  44. all_same = True
  45. for v0, v1 in zip(trajs[0], trajs[1]):
  46. if v0 != v1:
  47. all_same = False
  48. self.assertTrue(all_same)
  49. # trial1 and trial2 use different seeds and thus
  50. # most rewards tend to be different.
  51. diff_cnt = 0
  52. for v1, v2 in zip(trajs[1], trajs[2]):
  53. if v1 != v2:
  54. diff_cnt += 1
  55. self.assertTrue(diff_cnt > 8)
  56. if __name__ == "__main__":
  57. import pytest
  58. import sys
  59. sys.exit(pytest.main(["-v", __file__]))