custom_train_fn.py

  1. """Example of a custom training workflow. Run this for a demo.
  2. This example shows:
  3. - using Tune trainable functions to implement custom training workflows
  4. You can visualize experiment results in ~/ray_results using TensorBoard.
  5. """
  6. import argparse
  7. import os
  8. import ray
  9. from ray import tune
  10. from ray.rllib.agents.ppo import PPOTrainer
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument(
  13. "--framework",
  14. choices=["tf", "tf2", "tfe", "torch"],
  15. default="tf",
  16. help="The DL framework specifier.")
def my_train_fn(config, reporter):
    iterations = config.pop("train-iterations", 10)

    # Train for n iterations with high LR
    agent1 = PPOTrainer(env="CartPole-v0", config=config)
    for _ in range(iterations):
        result = agent1.train()
        result["phase"] = 1
        reporter(**result)
        phase1_time = result["timesteps_total"]
    state = agent1.save()
    agent1.stop()
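
    # The checkpoint path returned by `agent1.save()` above lets a fresh
    # trainer pick up the phase-1 weights, so phase 2 continues from where
    # phase 1 left off rather than starting from scratch.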
    # Train for n iterations with low LR
    config["lr"] = 0.0001
    agent2 = PPOTrainer(env="CartPole-v0", config=config)
    agent2.restore(state)
    for _ in range(iterations):
        result = agent2.train()
        result["phase"] = 2
        result["timesteps_total"] += phase1_time  # keep time moving forward
        reporter(**result)
    agent2.stop()


if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    config = {
        # Special flag signalling `my_train_fn` how many iters to do.
        "train-iterations": 2,
        "lr": 0.01,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "framework": args.framework,
    }
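    # `PPOTrainer.default_resource_request(config)` reports the CPU/GPU
    # resources a PPO trainer would need for this config; passing it as
    # `resources_per_trial` lets Tune reserve the same resources for the
    # function trainable, which builds PPOTrainer instances internally.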
    resources = PPOTrainer.default_resource_request(config)
    tune.run(my_train_fn, resources_per_trial=resources, config=config)