custom_model_loss_and_metrics.py

  1. """Example of using custom_loss() with an imitation learning loss.
  2. The default input file is too small to learn a good policy, but you can
  3. generate new experiences for IL training as follows:
  4. To generate experiences:
  5. $ ./train.py --run=PG --config='{"output": "/tmp/cartpole"}' --env=CartPole-v0
  6. To train on experiences with joint PG + IL loss:
  7. $ python custom_loss.py --input-files=/tmp/cartpole
  8. """

import argparse
from pathlib import Path
import os

import ray
from ray import tune
from ray.rllib.examples.models.custom_loss_model import CustomLossModel, \
    TorchCustomLossModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY

tf1, tf, tfv = try_import_tf()
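# `try_import_tf()` is RLlib's guarded TensorFlow import: it returns a
# (tf1, tf, version) triple and yields None modules rather than raising
# when TensorFlow is absent, so --framework=torch runs still work.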

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PG",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument(
    "--input-files",
    type=str,
    default=os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "../tests/data/cartpole/small.json"))

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    # Bazel makes it hard to find files specified in `args` (and `data`).
    # Look for them here.
    if not os.path.exists(args.input_files):
        # This script runs in the ray/rllib/examples dir.
        rllib_dir = Path(__file__).parent.parent
        input_dir = rllib_dir.absolute().joinpath(args.input_files)
        args.input_files = str(input_dir)

    ModelCatalog.register_custom_model(
        "custom_loss", TorchCustomLossModel
        if args.framework == "torch" else CustomLossModel)
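
    # The model registered above overrides two ModelV2 hooks. A rough sketch
    # only; the actual implementations live in
    # ray.rllib.examples.models.custom_loss_model:
    #
    #   def custom_loss(self, policy_loss, loss_inputs):
    #       # Build an imitation (behavioral-cloning) loss from the offline
    #       # experiences and fold it into the policy loss term(s).
    #       return [loss + self.imitation_loss for loss in policy_loss]
    #
    #   def metrics(self):
    #       # Values returned here surface under the "model" key of the
    #       # learner info dict (checked at the bottom of this script).
    #       return {"imitation_loss": self.imitation_loss}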

    config = {
        "env": "CartPole-v0",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "model": {
            "custom_model": "custom_loss",
            "custom_model_config": {
                "input_files": args.input_files,
            },
        },
        "framework": args.framework,
    }
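
    # `custom_model_config` is forwarded to the model's constructor; the
    # example model uses `input_files` to load the offline episodes (e.g.
    # via RLlib's JsonReader) that feed the imitation loss.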

    stop = {
        "training_iteration": args.stop_iters,
    }

    analysis = tune.run(args.run, config=config, stop=stop, verbose=1)
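
    # `analysis.results` maps trial IDs to each trial's last result dict;
    # grab the single trial's learner info for the structure checks below.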
    info = analysis.results[next(iter(analysis.results))]["info"]

    # Torch metrics structure.
    if args.framework == "torch":
        assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_POLICY_ID]
        assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID]
        assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_POLICY_ID]
    # TODO: (sven) Make sure the metrics structure gets unified between
    #  tf and torch. Tf should work like current torch:
    #  info:
    #    learner:
    #      [policy_id]
    #        learner_stats: [return values of policy's `stats_fn`]
    #        model: [return values of ModelV2's `metrics` method]
    #        custom_metrics: [return values of callback: `on_learn_on_batch`]
    else:
        assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID][
            LEARNER_STATS_KEY]