# test_ray_client.py
# Tests that RLlib / Tune training runs correctly through a Ray Client connection.
import os
import sys
import unittest

import pytest

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.util.client.ray_client_helpers import ray_start_client_server
  10. class TestRayClient(unittest.TestCase):
  11. def test_connection(self):
  12. with ray_start_client_server():
  13. assert ray.util.client.ray.is_connected()
  14. assert ray.util.client.ray.is_connected() is False
  15. def test_custom_train_fn(self):
  16. with ray_start_client_server():
  17. assert ray.util.client.ray.is_connected()
  18. config = {
  19. # Special flag signalling `my_train_fn` how many iters to do.
  20. "train-iterations": 2,
  21. "lr": 0.01,
  22. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  23. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  24. "num_workers": 0,
  25. "framework": "tf",
  26. }
  27. resources = ppo.PPOTrainer.default_resource_request(config)
  28. from ray.rllib.examples.custom_train_fn import my_train_fn
  29. tune.run(my_train_fn, resources_per_trial=resources, config=config)
  30. def test_cartpole_lstm(self):
  31. with ray_start_client_server():
  32. assert ray.util.client.ray.is_connected()
  33. config = {
  34. "env": StatelessCartPole,
  35. }
  36. stop = {
  37. "training_iteration": 3,
  38. }
  39. tune.run("PPO", config=config, stop=stop, verbose=2)
  40. def test_custom_experiment(self):
  41. with ray_start_client_server(ray_init_kwargs={"num_cpus": 3}):
  42. assert ray.util.client.ray.is_connected()
  43. config = ppo.DEFAULT_CONFIG.copy()
  44. # Special flag signalling `experiment` how many iters to do.
  45. config["train-iterations"] = 2
  46. config["env"] = "CartPole-v0"
  47. from ray.rllib.examples.custom_experiment import experiment
  48. tune.run(
  49. experiment,
  50. config=config,
  51. resources_per_trial=ppo.PPOTrainer.default_resource_request(
  52. config))
  53. if __name__ == "__main__":
  54. sys.exit(pytest.main(["-v", __file__]))