1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import os
- import sys
- import unittest
- import pytest
- import ray
- from ray import air
- from ray import tune
- import ray.rllib.algorithms.ppo as ppo
- from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
- from ray.util.client.ray_client_helpers import ray_start_client_server
- class TestRayClient(unittest.TestCase):
- def test_connection(self):
- with ray_start_client_server():
- assert ray.util.client.ray.is_connected()
- assert ray.util.client.ray.is_connected() is False
- def test_custom_train_fn(self):
- with ray_start_client_server():
- assert ray.util.client.ray.is_connected()
- config = {
- # Special flag signalling `my_train_fn` how many iters to do.
- "train-iterations": 2,
- "lr": 0.01,
- # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
- "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
- "num_workers": 0,
- "framework": "tf",
- }
- resources = ppo.PPO.default_resource_request(config)
- from ray.rllib.examples.custom_train_fn import my_train_fn
- tune.Tuner(
- tune.with_resources(my_train_fn, resources),
- param_space=config,
- ).fit()
- def test_cartpole_lstm(self):
- with ray_start_client_server():
- assert ray.util.client.ray.is_connected()
- config = {
- "env": StatelessCartPole,
- }
- stop = {
- "training_iteration": 3,
- }
- tune.Tuner(
- "PPO",
- param_space=config,
- run_config=air.RunConfig(stop=stop, verbose=2),
- ).fit()
- def test_custom_experiment(self):
- with ray_start_client_server(ray_init_kwargs={"num_cpus": 3}):
- assert ray.util.client.ray.is_connected()
- config = ppo.PPOConfig().environment("CartPole-v1")
- # Special flag signalling `experiment` how many iters to do.
- config = config.to_dict()
- config["train-iterations"] = 2
- from ray.rllib.examples.custom_experiment import experiment
- # Ray client does not seem to propagate the `fn._resources` property
- # correctly for imported functions. As a workaround, we can wrap the
- # imported function which forces a full transfer.
- def wrapped_experiment(config):
- experiment(config)
- tune.Tuner(
- tune.with_resources(
- wrapped_experiment, ppo.PPO.default_resource_request(config)
- ),
- param_space=config,
- ).fit()
- if __name__ == "__main__":
- sys.exit(pytest.main(["-v", __file__]))
|