custom_train_fn.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
"""Example of a custom training workflow. Run this for a demo.

This example shows:
  - using Tune trainable functions to implement custom training workflows
    (here: training PPO on CartPole-v1 first with a high learning rate, then,
    restored from that checkpoint, with a low learning rate).

You can visualize experiment results in ~/ray_results using TensorBoard.
"""
  6. import argparse
  7. import os
  8. import ray
  9. from ray import train, tune
  10. from ray.rllib.algorithms.ppo import PPO, PPOConfig
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument(
  13. "--framework",
  14. choices=["tf", "tf2", "torch"],
  15. default="torch",
  16. help="The DL framework specifier.",
  17. )
  18. def my_train_fn(config):
  19. iterations = config.pop("train-iterations", 10)
  20. config = PPOConfig().update_from_dict(config).environment("CartPole-v1")
  21. # Train for n iterations with high LR.
  22. config.lr = 0.01
  23. agent1 = config.build()
  24. for _ in range(iterations):
  25. result = agent1.train()
  26. result["phase"] = 1
  27. train.report(result)
  28. phase1_time = result["timesteps_total"]
  29. state = agent1.save()
  30. agent1.stop()
  31. # Train for n iterations with low LR
  32. config.lr = 0.0001
  33. agent2 = config.build()
  34. agent2.restore(state)
  35. for _ in range(iterations):
  36. result = agent2.train()
  37. result["phase"] = 2
  38. result["timesteps_total"] += phase1_time # keep time moving forward
  39. train.report(result)
  40. agent2.stop()
  41. if __name__ == "__main__":
  42. ray.init()
  43. args = parser.parse_args()
  44. config = {
  45. # Special flag signalling `my_train_fn` how many iters to do.
  46. "train-iterations": 2,
  47. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  48. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  49. "num_workers": 0,
  50. "framework": args.framework,
  51. }
  52. resources = PPO.default_resource_request(config)
  53. tuner = tune.Tuner(
  54. tune.with_resources(my_train_fn, resources=resources), param_space=config
  55. )
  56. tuner.fit()