test_custom_resource.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import pytest
  2. import ray
  3. from ray import air
  4. from ray import tune
  5. from ray.tune.registry import get_trainable_cls
  @pytest.mark.parametrize("algorithm", ["PPO", "APEX", "IMPALA"])
  def test_custom_resource(algorithm):
      """Train `algorithm` for one iteration while requesting a custom resource.

      Starts a fresh local Ray instance that advertises one unit of
      ``custom_resource``, asks for 0.01 of it per rollout worker, and runs a
      single Tune trial on CartPole-v1 (torch) until one training iteration
      has completed.

      Args:
          algorithm: RLlib algorithm name; resolved to its trainable class via
              ``get_trainable_cls`` and also passed to ``tune.Tuner``.

      Raises:
          AssertionError: If the Tune run reports any errored trial.
      """
      # `ray.is_initialized` is a function and must be called; the bare
      # attribute is always truthy, which made this guard meaningless.
      if ray.is_initialized():
          ray.shutdown()

      ray.init(
          resources={"custom_resource": 1},
          include_dashboard=False,
      )
      try:
          config = (
              get_trainable_cls(algorithm)
              .get_default_config()
              .environment("CartPole-v1")
              .framework("torch")
              .rollouts(num_rollout_workers=1)
              # Each worker requests a fraction of the single custom-resource
              # unit advertised by ray.init() above.
              .resources(
                  num_gpus=0,
                  custom_resources_per_worker={"custom_resource": 0.01},
              )
          )

          if algorithm == "APEX":
              # NOTE(review): presumably lets APEX begin learning within the
              # single iteration run below — confirm against ApexDQNConfig.
              config.num_steps_sampled_before_learning_starts = 0

          stop = {"training_iteration": 1}

          results = tune.Tuner(
              algorithm,
              param_space=config,
              run_config=air.RunConfig(stop=stop, verbose=0),
              tune_config=tune.TuneConfig(num_samples=1),
          ).fit()

          # Tuner.fit() can return normally even when a trial failed; such
          # failures are recorded on the returned ResultGrid, so check them.
          assert not results.errors, results.errors
      finally:
          # Always tear down the Ray instance this test started so that later
          # tests (or parametrizations) start from a clean state.
          ray.shutdown()
  if __name__ == "__main__":
      # Running this file directly invokes pytest on it verbosely and hands
      # pytest's exit status back to the shell.
      import sys

      exit_code = pytest.main(["-v", __file__])
      sys.exit(exit_code)