custom_model_loss_and_metrics.py

  1. """Example of using custom_loss() with an imitation learning loss under the Policy
  2. and ModelV2 API.
  3. The default input file is too small to learn a good policy, but you can
  4. generate new experiences for IL training as follows:
  5. To generate experiences:
  6. $ ./train.py --run=PG --config='{"output": "/tmp/cartpole"}' --env=CartPole-v1
  7. To train on experiences with joint PG + IL loss:
  8. $ python custom_loss.py --input-files=/tmp/cartpole
  9. """

import argparse
from pathlib import Path
import os

import ray
from ray import air, tune
from ray.rllib.examples.models.custom_loss_model import (
    CustomLossModel,
    TorchCustomLossModel,
)
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.tune.registry import get_trainable_cls

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PG", help="The RLlib-registered algorithm to use."
)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument(
    "--input-files",
    type=str,
    default=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "../tests/data/cartpole/small.json"
    ),
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    # Bazel makes it hard to find files specified in `args` (and `data`).
    # Look for them here.
    if not os.path.exists(args.input_files):
        # This script runs in the ray/rllib/examples dir.
        rllib_dir = Path(__file__).parent.parent
        input_dir = rllib_dir.absolute().joinpath(args.input_files)
        args.input_files = str(input_dir)

    ModelCatalog.register_custom_model(
        "custom_loss",
        TorchCustomLossModel if args.framework == "torch" else CustomLossModel,
    )
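    # The "custom_model" entry in the config below refers to the name
    # registered here; RLlib instantiates the registered class and calls its
    # `custom_loss()` when computing the training loss.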

    # TODO (Kourosh): This example needs to be migrated to the new RLModule / Learner
    # API. Users should just inherit the Learner and extend the loss_fn.
    config = (
        get_trainable_cls(args.run)
        .get_default_config()
        .environment("CartPole-v1")
        .framework(args.framework)
        .rollouts(num_rollout_workers=0)
        .training(
            model={
                "custom_model": "custom_loss",
                "custom_model_config": {
                    "input_files": args.input_files,
                },
            },
            _enable_learner_api=False,
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .rl_module(_enable_rl_module_api=False)
    )
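    # Note: `_enable_learner_api=False` and `_enable_rl_module_api=False` keep
    # this example on the older Policy/ModelV2 stack, where the `custom_loss()`
    # hook is available (see TODO above).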

    stop = {
        "training_iteration": args.stop_iters,
    }

    tuner = tune.Tuner(
        args.run,
        param_space=config,
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    results = tuner.fit()
    info = results.get_best_result().metrics["info"]

    # Torch metrics structure.
    if args.framework == "torch":
        assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_POLICY_ID]
        assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID]
        assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_POLICY_ID]

    # TODO: (sven) Make sure the metrics structure gets unified between
    #  tf and torch. Tf should work like current torch:
    #  info:
    #    learner:
    #      [policy_id]:
    #        learner_stats: [return values of policy's `stats_fn`]
    #        model: [return values of ModelV2's `metrics` method]
    #        custom_metrics: [return values of callback: `on_learn_on_batch`]
    else:
        assert "model" in info[LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]