test_envs_that_crash.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import unittest
  2. import ray
  3. from ray.rllib.algorithms.pg import pg
  4. from ray.rllib.algorithms.tests.test_worker_failures import (
  5. ForwardHealthCheckToEnvWorker,
  6. )
  7. from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
  8. from ray.rllib.utils.error import EnvError
  9. class TestEnvsThatCrash(unittest.TestCase):
  10. @classmethod
  11. def setUpClass(cls) -> None:
  12. ray.init(num_cpus=4)
  13. @classmethod
  14. def tearDownClass(cls) -> None:
  15. ray.shutdown()
  16. def test_crash_during_env_pre_checking(self):
  17. """Expect the env pre-checking to fail on each worker."""
  18. config = (
  19. pg.PGConfig()
  20. .rollouts(num_rollout_workers=2, num_envs_per_worker=4)
  21. .environment(
  22. env=CartPoleCrashing,
  23. env_config={
  24. # Crash prob=100% (during pre-checking's `step()` test calls).
  25. "p_crash": 1.0,
  26. "init_time_s": 0.5,
  27. },
  28. )
  29. )
  30. # Expect ValueError due to pre-checking failing (our pre-checker module
  31. # raises a ValueError if `step()` fails).
  32. self.assertRaisesRegex(
  33. ValueError,
  34. "Simulated env crash",
  35. lambda: config.build(),
  36. )
  37. def test_crash_during_sampling(self):
  38. """Expect some sub-envs to fail (and not recover)."""
  39. config = (
  40. pg.PGConfig()
  41. .rollouts(num_rollout_workers=2, num_envs_per_worker=3)
  42. .environment(
  43. env=CartPoleCrashing,
  44. env_config={
  45. # Crash prob=20%.
  46. "p_crash": 0.2,
  47. "init_time_s": 0.3,
  48. # Make sure nothing happens during pre-checks.
  49. "skip_env_checking": True,
  50. },
  51. )
  52. )
  53. # Pre-checking disables, so building the Algorithm is save.
  54. algo = config.build()
  55. # Expect EnvError due to the sub-env(s) crashing on the different workers
  56. # and `ignore_worker_failures=False` (so the original EnvError should
  57. # just be bubbled up by RLlib Algorithm and tune.Trainable during the `step()`
  58. # call).
  59. self.assertRaisesRegex(EnvError, "Simulated env crash", lambda: algo.train())
  60. def test_crash_only_one_worker_during_sampling_but_ignore(self):
  61. """Expect some sub-envs to fail (and not recover), but ignore."""
  62. config = (
  63. pg.PGConfig()
  64. .rollouts(
  65. env_runner_cls=ForwardHealthCheckToEnvWorker,
  66. num_rollout_workers=2,
  67. num_envs_per_worker=3,
  68. # Ignore worker failures (continue with worker #2).
  69. ignore_worker_failures=True,
  70. )
  71. .environment(
  72. env=CartPoleCrashing,
  73. env_config={
  74. # Crash prob=80%.
  75. "p_crash": 0.8,
  76. # Only crash on worker with index 1.
  77. "crash_on_worker_indices": [1],
  78. # Make sure nothing happens during pre-checks.
  79. "skip_env_checking": True,
  80. },
  81. )
  82. )
  83. # Pre-checking disables, so building the Algorithm is save.
  84. algo = config.build()
  85. # Expect some errors being logged here, but in general, should continue
  86. # as we ignore worker failures.
  87. algo.train()
  88. # One worker has been removed -> Only one left.
  89. self.assertEqual(algo.workers.num_healthy_remote_workers(), 1)
  90. algo.stop()
  91. def test_crash_only_one_worker_during_sampling_but_recreate(self):
  92. """Expect some sub-envs to fail (and not recover), but re-create worker."""
  93. config = (
  94. pg.PGConfig()
  95. .rollouts(
  96. env_runner_cls=ForwardHealthCheckToEnvWorker,
  97. num_rollout_workers=2,
  98. rollout_fragment_length=10,
  99. num_envs_per_worker=3,
  100. # Re-create failed workers (then continue).
  101. recreate_failed_workers=True,
  102. )
  103. .training(train_batch_size=60)
  104. .environment(
  105. env=CartPoleCrashing,
  106. env_config={
  107. "crash_after_n_steps": 10,
  108. # Crash prob=100%, so test is deterministic.
  109. "p_crash": 1.0,
  110. # Only crash on worker with index 2.
  111. "crash_on_worker_indices": [2],
  112. # Make sure nothing happens during pre-checks.
  113. "skip_env_checking": True,
  114. },
  115. )
  116. )
  117. # Pre-checking disables, so building the Algorithm is save.
  118. algo = config.build()
  119. # Try to re-create for infinite amount of times.
  120. # The worker recreation/ignore tolerance used to be hard-coded to 3, but this
  121. # has now been
  122. for _ in range(10):
  123. # Expect some errors being logged here, but in general, should continue
  124. # as we recover from all worker failures.
  125. algo.train()
  126. # One worker has been removed.
  127. self.assertEqual(algo.workers.num_healthy_remote_workers(), 1)
  128. algo.stop()
  129. def test_crash_sub_envs_during_sampling_but_restart_sub_envs(self):
  130. """Expect sub-envs to fail (and not recover), but re-start them individually."""
  131. config = (
  132. pg.PGConfig()
  133. .rollouts(
  134. num_rollout_workers=2,
  135. num_envs_per_worker=3,
  136. # Re-start failed individual sub-envs (then continue).
  137. # This means no workers will ever fail due to individual env errors
  138. # (only maybe for reasons other than the env).
  139. restart_failed_sub_environments=True,
  140. # If the worker was affected by an error (other than the env error),
  141. # allow it to be removed, but training will continue.
  142. ignore_worker_failures=True,
  143. )
  144. .environment(
  145. env=CartPoleCrashing,
  146. env_config={
  147. # Crash prob=1%.
  148. "p_crash": 0.01,
  149. # Make sure nothing happens during pre-checks.
  150. "skip_env_checking": True,
  151. },
  152. )
  153. )
  154. # Pre-checking disables, so building the Algorithm is save.
  155. algo = config.build()
  156. # Try to re-create the sub-env for infinite amount of times.
  157. # The worker recreation/ignore tolerance used to be hard-coded to 3, but this
  158. # has now been
  159. for _ in range(10):
  160. # Expect some errors being logged here, but in general, should continue
  161. # as we recover from all sub-env failures.
  162. algo.train()
  163. # No worker has been removed. Still 2 left.
  164. self.assertEqual(algo.workers.num_healthy_remote_workers(), 2)
  165. algo.stop()
  166. if __name__ == "__main__":
  167. import pytest
  168. import sys
  169. sys.exit(pytest.main(["-v", __file__]))