123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import gym
- import unittest
- import ray
- from ray.rllib import _register_all
- from ray.rllib.agents.registry import get_trainer_class
- from ray.rllib.utils.test_utils import framework_iterator
- from ray.tune.registry import register_env
class FaultInjectEnv(gym.Env):
    """Env that fails upon calling `step()`, but only for some remote workers.

    Wraps a "CartPole-v0" gym env. The worker indices that should produce the
    failure (a ValueError) can be provided by a list (of ints) under the
    "bad_indices" key in the env's config. If that key is absent, no worker
    fails.

    Examples:
        >>> from ray.rllib.env.env_context import EnvContext
        >>> # This env will fail for workers 1 and 2 (not for the local worker
        >>> # or any others with an index > 2).
        >>> bad_env = FaultInjectEnv(
        ...     EnvContext({"bad_indices": [1, 2]},
        ...                worker_index=1, num_workers=3))
    """

    def __init__(self, config):
        """Create the wrapped CartPole env and mirror its spaces.

        Args:
            config: The env config. `step()` reads both `config.worker_index`
                and the optional "bad_indices" key, so this is expected to be
                an `EnvContext` (as in the class docstring example).
        """
        self.env = gym.make("CartPole-v0")
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.config = config

    def reset(self):
        """Reset the wrapped env and return whatever its `reset()` returns."""
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env, unless this worker is configured to fail.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            Whatever the wrapped env's `step()` returns.

        Raises:
            ValueError: If `config.worker_index` is listed in
                config["bad_indices"].
        """
        # A missing "bad_indices" key means "no worker fails" rather than
        # surfacing as a KeyError on the first step.
        bad_indices = self.config.get("bad_indices", [])
        if self.config.worker_index in bad_indices:
            raise ValueError("This is a simulated error from "
                             f"worker-idx={self.config.worker_index}.")
        return self.env.step(action)
class IgnoresWorkerFailure(unittest.TestCase):
    """Checks trainer behavior when some rollout workers' envs crash.

    Each test builds a trainer with two remote workers whose envs are
    `FaultInjectEnv` instances, with `ignore_worker_failures=True`. If only
    one worker fails, training must go on with one healthy worker. If both
    fail, `train()` must raise.
    """

    # CPU count every Ray session in this test case is started with.
    _NUM_CPUS = 6

    @classmethod
    def setUpClass(cls) -> None:
        # Use the same CPU count as `do_test`. Otherwise `do_test`'s own init
        # (with ignore_reinit_error=True) is a no-op for the first test, and
        # that test would run on a default-sized session.
        ray.init(num_cpus=cls._NUM_CPUS)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def do_test(self, alg, config, fn=None):
        """Run `fn(alg, config)` in a Ray session, then tear the session down.

        Args:
            alg: Trainer name, resolved via `get_trainer_class`.
            config: Trainer config overrides.
            fn: Test body to run; defaults to `_do_test_fault_recover`.
        """
        fn = fn or self._do_test_fault_recover
        try:
            # No-op if a session is already up (e.g. from setUpClass).
            ray.init(num_cpus=self._NUM_CPUS, ignore_reinit_error=True)
            fn(alg, config)
        finally:
            ray.shutdown()
            _register_all()  # re-register the evicted objects

    def _make_fault_setup(self, alg, config, bad_indices):
        """Register "fault_env" and return (trainer class, fault config).

        The returned config is a shallow copy of `config`. It is set up for
        two workers, with failures ignored and `bad_indices` as the
        failing workers.
        """
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_trainer_class(alg)
        config = dict(config)
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": bad_indices}
        return agent_cls, config

    def _do_test_fault_recover(self, alg, config):
        """Training must continue when only worker idx=1 fails."""
        # Make worker idx=1 fail. Other workers will be ok.
        agent_cls, config = self._make_fault_setup(alg, config, [1])
        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            try:
                result = a.train()
                # Exactly one of the two workers must remain healthy.
                self.assertEqual(result["num_healthy_workers"], 1)
            finally:
                a.stop()

    def _do_test_fault_fatal(self, alg, config):
        """Test raises real error when out of workers."""
        # Make both worker idx=1 and 2 fail.
        agent_cls, config = self._make_fault_setup(alg, config, [1, 2])
        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            try:
                self.assertRaises(Exception, a.train)
            finally:
                a.stop()

    def test_fatal(self):
        # test the case where all workers fail
        self.do_test("PG", {"optimizer": {}}, fn=self._do_test_fault_fatal)

    def test_async_grads(self):
        self.do_test("A3C", {"optimizer": {"grads_per_step": 1}})

    def test_async_replay(self):
        self.do_test(
            "APEX", {
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_time_s_per_reporting": 1,
                "explore": False,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            })

    def test_async_samples(self):
        self.do_test("IMPALA", {"num_gpus": 0})

    def test_sync_replay(self):
        self.do_test("DQN", {"timesteps_per_iteration": 1})

    def test_multi_g_p_u(self):
        self.do_test(
            "PPO", {
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "rollout_fragment_length": 10,
                "sgd_minibatch_size": 1,
            })

    def test_sync_samples(self):
        self.do_test("PG", {"optimizer": {}})

    def test_async_sampling_option(self):
        self.do_test("PG", {"optimizer": {}, "sample_async": True})
if __name__ == "__main__":
    # Run this file's tests through pytest and propagate its exit status.
    import sys
    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|