test_ray_client.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import os
  2. import sys
  3. import unittest
  4. import pytest
  5. import ray
  6. from ray import air
  7. from ray import tune
  8. import ray.rllib.algorithms.ppo as ppo
  9. from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
  10. from ray.util.client.ray_client_helpers import ray_start_client_server
  11. class TestRayClient(unittest.TestCase):
  12. def test_connection(self):
  13. with ray_start_client_server():
  14. assert ray.util.client.ray.is_connected()
  15. assert ray.util.client.ray.is_connected() is False
  16. def test_custom_train_fn(self):
  17. with ray_start_client_server():
  18. assert ray.util.client.ray.is_connected()
  19. config = {
  20. # Special flag signalling `my_train_fn` how many iters to do.
  21. "train-iterations": 2,
  22. "lr": 0.01,
  23. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  24. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  25. "num_workers": 0,
  26. "framework": "tf",
  27. }
  28. resources = ppo.PPO.default_resource_request(config)
  29. from ray.rllib.examples.custom_train_fn import my_train_fn
  30. tune.Tuner(
  31. tune.with_resources(my_train_fn, resources),
  32. param_space=config,
  33. ).fit()
  34. def test_cartpole_lstm(self):
  35. with ray_start_client_server():
  36. assert ray.util.client.ray.is_connected()
  37. config = {
  38. "env": StatelessCartPole,
  39. }
  40. stop = {
  41. "training_iteration": 3,
  42. }
  43. tune.Tuner(
  44. "PPO",
  45. param_space=config,
  46. run_config=air.RunConfig(stop=stop, verbose=2),
  47. ).fit()
  48. def test_custom_experiment(self):
  49. with ray_start_client_server(ray_init_kwargs={"num_cpus": 3}):
  50. assert ray.util.client.ray.is_connected()
  51. config = ppo.PPOConfig().environment("CartPole-v1")
  52. # Special flag signalling `experiment` how many iters to do.
  53. config = config.to_dict()
  54. config["train-iterations"] = 2
  55. from ray.rllib.examples.custom_experiment import experiment
  56. # Ray client does not seem to propagate the `fn._resources` property
  57. # correctly for imported functions. As a workaround, we can wrap the
  58. # imported function which forces a full transfer.
  59. def wrapped_experiment(config):
  60. experiment(config)
  61. tune.Tuner(
  62. tune.with_resources(
  63. wrapped_experiment, ppo.PPO.default_resource_request(config)
  64. ),
  65. param_space=config,
  66. ).fit()
  67. if __name__ == "__main__":
  68. sys.exit(pytest.main(["-v", __file__]))