# test_ignore_worker_failure.py
  1. import gym
  2. import unittest
  3. import ray
  4. from ray.rllib import _register_all
  5. from ray.rllib.agents.registry import get_trainer_class
  6. from ray.rllib.utils.test_utils import framework_iterator
  7. from ray.tune.registry import register_env
  8. class FaultInjectEnv(gym.Env):
  9. """Env that fails upon calling `step()`, but only for some remote workers.
  10. The worker indices that should produce the failure (a ValueError) can be
  11. provided by a list (of ints) under the "bad_indices" key in the env's
  12. config.
  13. Examples:
  14. >>> from ray.rllib.env.env_context import EnvContext
  15. >>> # This env will fail for workers 1 and 2 (not for the local worker
  16. >>> # or any others with an index > 2).
  17. >>> bad_env = FaultInjectEnv(
  18. ... EnvContext({"bad_indices": [1, 2]},
  19. ... worker_index=1, num_workers=3))
  20. """
  21. def __init__(self, config):
  22. self.env = gym.make("CartPole-v0")
  23. self.action_space = self.env.action_space
  24. self.observation_space = self.env.observation_space
  25. self.config = config
  26. def reset(self):
  27. return self.env.reset()
  28. def step(self, action):
  29. if self.config.worker_index in self.config["bad_indices"]:
  30. raise ValueError("This is a simulated error from "
  31. f"worker-idx={self.config.worker_index}.")
  32. return self.env.step(action)
  33. class IgnoresWorkerFailure(unittest.TestCase):
  34. @classmethod
  35. def setUpClass(cls) -> None:
  36. ray.init()
  37. @classmethod
  38. def tearDownClass(cls) -> None:
  39. ray.shutdown()
  40. def do_test(self, alg, config, fn=None):
  41. fn = fn or self._do_test_fault_recover
  42. try:
  43. ray.init(num_cpus=6, ignore_reinit_error=True)
  44. fn(alg, config)
  45. finally:
  46. ray.shutdown()
  47. _register_all() # re-register the evicted objects
  48. def _do_test_fault_recover(self, alg, config):
  49. register_env("fault_env", lambda c: FaultInjectEnv(c))
  50. agent_cls = get_trainer_class(alg)
  51. # Test fault handling
  52. config["num_workers"] = 2
  53. config["ignore_worker_failures"] = True
  54. # Make worker idx=1 fail. Other workers will be ok.
  55. config["env_config"] = {"bad_indices": [1]}
  56. for _ in framework_iterator(config, frameworks=("torch", "tf")):
  57. a = agent_cls(config=config, env="fault_env")
  58. result = a.train()
  59. self.assertTrue(result["num_healthy_workers"], 1)
  60. a.stop()
  61. def _do_test_fault_fatal(self, alg, config):
  62. register_env("fault_env", lambda c: FaultInjectEnv(c))
  63. agent_cls = get_trainer_class(alg)
  64. # Test raises real error when out of workers
  65. config["num_workers"] = 2
  66. config["ignore_worker_failures"] = True
  67. # Make both worker idx=1 and 2 fail.
  68. config["env_config"] = {"bad_indices": [1, 2]}
  69. for _ in framework_iterator(config, frameworks=("torch", "tf")):
  70. a = agent_cls(config=config, env="fault_env")
  71. self.assertRaises(Exception, lambda: a.train())
  72. a.stop()
  73. def test_fatal(self):
  74. # test the case where all workers fail
  75. self.do_test("PG", {"optimizer": {}}, fn=self._do_test_fault_fatal)
  76. def test_async_grads(self):
  77. self.do_test("A3C", {"optimizer": {"grads_per_step": 1}})
  78. def test_async_replay(self):
  79. self.do_test(
  80. "APEX", {
  81. "timesteps_per_iteration": 1000,
  82. "num_gpus": 0,
  83. "min_time_s_per_reporting": 1,
  84. "explore": False,
  85. "learning_starts": 1000,
  86. "target_network_update_freq": 100,
  87. "optimizer": {
  88. "num_replay_buffer_shards": 1,
  89. },
  90. })
  91. def test_async_samples(self):
  92. self.do_test("IMPALA", {"num_gpus": 0})
  93. def test_sync_replay(self):
  94. self.do_test("DQN", {"timesteps_per_iteration": 1})
  95. def test_multi_g_p_u(self):
  96. self.do_test(
  97. "PPO", {
  98. "num_sgd_iter": 1,
  99. "train_batch_size": 10,
  100. "rollout_fragment_length": 10,
  101. "sgd_minibatch_size": 1,
  102. })
  103. def test_sync_samples(self):
  104. self.do_test("PG", {"optimizer": {}})
  105. def test_async_sampling_option(self):
  106. self.do_test("PG", {"optimizer": {}, "sample_async": True})
  107. if __name__ == "__main__":
  108. import pytest
  109. import sys
  110. sys.exit(pytest.main(["-v", __file__]))