# test_vector_env.py (700 B)
  1. import gym
  2. import unittest
  3. from ray.rllib.env.vector_env import VectorEnv
  class Info(dict):
      """Plain ``dict`` subclass returned as the info of ``MockEnvDictSubclass.step``.

      It adds no behavior of its own. Presumably it exists so the test can
      confirm that info objects which merely *subclass* ``dict`` are accepted.
      """
  class MockEnvDictSubclass(gym.Env):
      """Minimal gym env whose ``step()`` info is an ``Info`` (a dict subclass).

      Every step reports observation 0, reward 1 and ``done=True``.
      """

      def __init__(self):
          # One possible observation value and two possible actions.
          obs_space = gym.spaces.Discrete(1)
          act_space = gym.spaces.Discrete(2)
          self.observation_space = obs_space
          self.action_space = act_space

      def reset(self):
          """Return the initial observation, which is always 0."""
          initial_obs = 0
          return initial_obs

      def step(self, action):
          """Ignore ``action`` and end the episode immediately.

          Returns an ``(obs, reward, done, info)`` tuple whose info is an
          ``Info`` instance rather than an exact ``dict``.
          """
          obs, reward, done = 0, 1, True
          return obs, reward, done, Info()
  class TestExternalEnv(unittest.TestCase):
      """Tests for ``VectorEnv.vectorize_gym_envs`` / ``vector_step``.

      NOTE(review): the class name says "ExternalEnv", but the test below
      exercises ``VectorEnv``. It looks like a copy-paste leftover. The name is
      kept unchanged so existing test IDs and selectors keep working.
      """

      def test_vector_step(self):
          """vector_step() must accept, and pass through, dict-subclass infos."""
          num_envs = 3
          env = VectorEnv.vectorize_gym_envs(
              make_env=lambda _: MockEnvDictSubclass(), num_envs=num_envs)
          # Previously the result was discarded, so the test only proved "no
          # exception". Also check that one entry comes back per sub-env and
          # that the infos keep their ``Info`` type instead of being dropped
          # or coerced.
          obs, rewards, dones, infos = env.vector_step([0] * num_envs)
          self.assertEqual(len(obs), num_envs)
          self.assertEqual(len(rewards), num_envs)
          self.assertEqual(len(dones), num_envs)
          self.assertEqual(len(infos), num_envs)
          for info in infos:
              self.assertIsInstance(info, Info)
  if __name__ == "__main__":
      # When run as a script, delegate to pytest and propagate its exit status.
      import sys

      import pytest

      exit_code = pytest.main(["-v", __file__])
      sys.exit(exit_code)