custom_logger.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
"""
This example script demonstrates how to define a custom logger object for
any RLlib Algorithm via the `logger_config` setting of the Algorithm's config
(`AlgorithmConfig.debugging(logger_config=...)`).

By default (logger_config=None), RLlib constructs a tune UnifiedLogger
object, which logs JSON, CSV, and TBX output.

The examples below show how to:
- Disable logging entirely.
- Use only one of tune's JSON, CSV, or TBX loggers.
- Define a custom logger (by subclassing `ray.tune.logger.Logger`).
"""
  11. import argparse
  12. import os
  13. from ray.rllib.utils.test_utils import check_learning_achieved
  14. from ray.tune.logger import Logger, LegacyLoggerCallback
  15. from ray.tune.registry import get_trainable_cls
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument(
  18. "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
  19. )
  20. parser.add_argument("--num-cpus", type=int, default=0)
  21. parser.add_argument(
  22. "--framework",
  23. choices=["tf", "tf2", "torch"],
  24. default="torch",
  25. help="The DL framework specifier.",
  26. )
  27. parser.add_argument(
  28. "--as-test",
  29. action="store_true",
  30. help="Whether this script should be run as a test: --stop-reward must "
  31. "be achieved within --stop-timesteps AND --stop-iters.",
  32. )
  33. parser.add_argument(
  34. "--stop-iters", type=int, default=200, help="Number of iterations to train."
  35. )
  36. parser.add_argument(
  37. "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
  38. )
  39. parser.add_argument(
  40. "--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
  41. )
  42. class MyPrintLogger(Logger):
  43. """Logs results by simply printing out everything."""
  44. def _init(self):
  45. # Custom init function.
  46. print("Initializing ...")
  47. # Setting up our log-line prefix.
  48. self.prefix = self.config.get("logger_config").get("prefix")
  49. def on_result(self, result: dict):
  50. # Define, what should happen on receiving a `result` (dict).
  51. print(f"{self.prefix}: {result}")
  52. def close(self):
  53. # Releases all resources used by this logger.
  54. print("Closing")
  55. def flush(self):
  56. # Flushing all possible disk writes to permanent storage.
  57. print("Flushing ;)", flush=True)
  58. if __name__ == "__main__":
  59. import ray
  60. from ray import air, tune
  61. args = parser.parse_args()
  62. ray.init(num_cpus=args.num_cpus or None)
  63. config = (
  64. get_trainable_cls(args.run)
  65. .get_default_config()
  66. .environment(
  67. "CartPole-v1" if args.run not in ["DDPG", "TD3"] else "Pendulum-v1"
  68. )
  69. # Run with tracing enabled for tf2.
  70. .framework(args.framework)
  71. # Setting up a custom logger config.
  72. # ----------------------------------
  73. # The following are different examples of custom logging setups:
  74. # 1) Disable logging entirely.
  75. # "logger_config": {
  76. # # Use the tune.logger.NoopLogger class for no logging.
  77. # "type": "ray.tune.logger.NoopLogger",
  78. # },
  79. # 2) Use tune's JsonLogger only.
  80. # Alternatively, use `CSVLogger` or `TBXLogger` instead of
  81. # `JsonLogger` in the "type" key below.
  82. # "logger_config": {
  83. # "type": "ray.tune.logger.JsonLogger",
  84. # # Optional: Custom logdir (do not define this here
  85. # # for using ~/ray_results/...).
  86. # "logdir": "/tmp",
  87. # },
  88. # 3) Custom logger (see `MyPrintLogger` class above).
  89. .debugging(
  90. logger_config={
  91. # Provide the class directly or via fully qualified class
  92. # path.
  93. "type": MyPrintLogger,
  94. # `config` keys:
  95. "prefix": "ABC",
  96. # Optional: Custom logdir (do not define this here
  97. # for using ~/ray_results/...).
  98. # "logdir": "/somewhere/on/my/file/system/"
  99. }
  100. )
  101. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  102. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  103. )
  104. stop = {
  105. "training_iteration": args.stop_iters,
  106. "timesteps_total": args.stop_timesteps,
  107. "episode_reward_mean": args.stop_reward,
  108. }
  109. tuner = tune.Tuner(
  110. args.run,
  111. param_space=config.to_dict(),
  112. run_config=air.RunConfig(
  113. stop=stop,
  114. verbose=2,
  115. callbacks=[LegacyLoggerCallback(MyPrintLogger)],
  116. ),
  117. )
  118. results = tuner.fit()
  119. if args.as_test:
  120. check_learning_achieved(results, args.stop_reward)
  121. ray.shutdown()