12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- import unittest
- import ray
- import ray.rllib.agents.impala as impala
- import ray.rllib.agents.pg as pg
- from ray.rllib.utils.error import EnvError
- from ray.rllib.utils.test_utils import framework_iterator
class TestErrors(unittest.TestCase):
    """Tests various failure-modes, making sure we produce meaningful errmsgs.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # A single Ray instance is started once and shared by all tests here.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_no_gpus_error(self):
        """Tests errors related to no-GPU/too-few GPUs/etc.

        This test will only work ok on a CPU-only machine.
        NOTE(review): relies on the IMPALA default config requesting at least
        one GPU (nothing here sets `num_gpus`) — confirm against the defaults.
        """
        config = impala.DEFAULT_CONFIG.copy()
        env = "CartPole-v0"

        for _ in framework_iterator(config):
            with self.assertRaisesRegex(
                RuntimeError,
                # (?s): "dot matches all" (also newlines).
                "(?s)Found 0 GPUs on your machine.+To change the config",
            ):
                impala.ImpalaTrainer(config=config, env=env)

    def test_bad_envs(self):
        """Tests different "bad env" errors.

        Each env string below must make `PGTrainer` construction raise an
        `EnvError` whose message quotes the offending env string.
        """
        import re

        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0

        bad_envs = [
            # Non existing/non-registered gym env string.
            "Alien-Attack-v42",
            # Malformed gym env string (must have v\d at end).
            "Alien-Attack-part-42",
            # Non-existing class in a full-class-path.
            "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist",
            # Non-existing module inside a full-class-path.
            "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv",
        ]

        for env in bad_envs:
            # Escape the env string so regex metacharacters in it (e.g. the
            # "." in class paths) are matched literally, not as wildcards.
            pattern = rf"The env string you provided \('{re.escape(env)}'\) is"
            with self.subTest(env=env):
                for _ in framework_iterator(config):
                    with self.assertRaisesRegex(EnvError, pattern):
                        pg.PGTrainer(config=config, env=env)
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and hand pytest's exit status to the
    # shell.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|