test_errors.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
import re
import unittest

import ray
import ray.rllib.agents.impala as impala
import ray.rllib.agents.pg as pg
from ray.rllib.utils.error import EnvError
from ray.rllib.utils.test_utils import framework_iterator
  7. class TestErrors(unittest.TestCase):
  8. """Tests various failure-modes, making sure we produce meaningful errmsgs.
  9. """
  10. @classmethod
  11. def setUpClass(cls) -> None:
  12. ray.init()
  13. @classmethod
  14. def tearDownClass(cls) -> None:
  15. ray.shutdown()
  16. def test_no_gpus_error(self):
  17. """Tests errors related to no-GPU/too-few GPUs/etc.
  18. This test will only work ok on a CPU-only machine.
  19. """
  20. config = impala.DEFAULT_CONFIG.copy()
  21. env = "CartPole-v0"
  22. for _ in framework_iterator(config):
  23. self.assertRaisesRegex(
  24. RuntimeError,
  25. # (?s): "dot matches all" (also newlines).
  26. "(?s)Found 0 GPUs on your machine.+To change the config",
  27. lambda: impala.ImpalaTrainer(config=config, env=env),
  28. )
  29. def test_bad_envs(self):
  30. """Tests different "bad env" errors.
  31. """
  32. config = pg.DEFAULT_CONFIG.copy()
  33. config["num_workers"] = 0
  34. # Non existing/non-registered gym env string.
  35. env = "Alien-Attack-v42"
  36. for _ in framework_iterator(config):
  37. self.assertRaisesRegex(
  38. EnvError,
  39. f"The env string you provided \\('{env}'\\) is",
  40. lambda: pg.PGTrainer(config=config, env=env),
  41. )
  42. # Malformed gym env string (must have v\d at end).
  43. env = "Alien-Attack-part-42"
  44. for _ in framework_iterator(config):
  45. self.assertRaisesRegex(
  46. EnvError,
  47. f"The env string you provided \\('{env}'\\) is",
  48. lambda: pg.PGTrainer(config=config, env=env),
  49. )
  50. # Non-existing class in a full-class-path.
  51. env = "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist"
  52. for _ in framework_iterator(config):
  53. self.assertRaisesRegex(
  54. EnvError,
  55. f"The env string you provided \\('{env}'\\) is",
  56. lambda: pg.PGTrainer(config=config, env=env),
  57. )
  58. # Non-existing module inside a full-class-path.
  59. env = "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv"
  60. for _ in framework_iterator(config):
  61. self.assertRaisesRegex(
  62. EnvError,
  63. f"The env string you provided \\('{env}'\\) is",
  64. lambda: pg.PGTrainer(config=config, env=env),
  65. )
  66. if __name__ == "__main__":
  67. import pytest
  68. import sys
  69. sys.exit(pytest.main(["-v", __file__]))