custom_logger.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """
  2. This example script demonstrates how one can define a custom logger
  3. object for any RLlib Trainer via the Trainer's config dict's
  4. "logger_config" key.
  5. By default (logger_config=None), RLlib will construct a tune
  6. UnifiedLogger object, which logs JSON, CSV, and TBX output.
  7. The examples below include:
  8. - Disable logging entirely.
  9. - Using only one of tune's Json, CSV, or TBX loggers.
  10. - Defining a custom logger (by sub-classing tune.logger.py::Logger).
  11. """
  12. import argparse
  13. import os
  14. from ray.rllib.utils.test_utils import check_learning_achieved
  15. from ray.tune.logger import Logger
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument(
  18. "--run",
  19. type=str,
  20. default="PPO",
  21. help="The RLlib-registered algorithm to use.")
  22. parser.add_argument("--num-cpus", type=int, default=0)
  23. parser.add_argument(
  24. "--framework",
  25. choices=["tf", "tf2", "tfe", "torch"],
  26. default="tf",
  27. help="The DL framework specifier.")
  28. parser.add_argument(
  29. "--as-test",
  30. action="store_true",
  31. help="Whether this script should be run as a test: --stop-reward must "
  32. "be achieved within --stop-timesteps AND --stop-iters.")
  33. parser.add_argument(
  34. "--stop-iters",
  35. type=int,
  36. default=200,
  37. help="Number of iterations to train.")
  38. parser.add_argument(
  39. "--stop-timesteps",
  40. type=int,
  41. default=100000,
  42. help="Number of timesteps to train.")
  43. parser.add_argument(
  44. "--stop-reward",
  45. type=float,
  46. default=150.0,
  47. help="Reward at which we stop training.")
  48. class MyPrintLogger(Logger):
  49. """Logs results by simply printing out everything.
  50. """
  51. def _init(self):
  52. # Custom init function.
  53. print("Initializing ...")
  54. # Setting up our log-line prefix.
  55. self.prefix = self.config.get("logger_config").get("prefix")
  56. def on_result(self, result: dict):
  57. # Define, what should happen on receiving a `result` (dict).
  58. print(f"{self.prefix}: {result}")
  59. def close(self):
  60. # Releases all resources used by this logger.
  61. print("Closing")
  62. def flush(self):
  63. # Flushing all possible disk writes to permanent storage.
  64. print("Flushing ;)", flush=True)
  65. if __name__ == "__main__":
  66. import ray
  67. from ray import tune
  68. args = parser.parse_args()
  69. ray.init(num_cpus=args.num_cpus or None)
  70. config = {
  71. "env": "CartPole-v0"
  72. if args.run not in ["DDPG", "TD3"] else "Pendulum-v1",
  73. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  74. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  75. "framework": args.framework,
  76. # Run with tracing enabled for tfe/tf2.
  77. "eager_tracing": args.framework in ["tfe", "tf2"],
  78. # Setting up a custom logger config.
  79. # ----------------------------------
  80. # The following are different examples of custom logging setups:
  81. # 1) Disable logging entirely.
  82. # "logger_config": {
  83. # # Use the tune.logger.NoopLogger class for no logging.
  84. # "type": "ray.tune.logger.NoopLogger",
  85. # },
  86. # 2) Use tune's JsonLogger only.
  87. # Alternatively, use `CSVLogger` or `TBXLogger` instead of
  88. # `JsonLogger` in the "type" key below.
  89. # "logger_config": {
  90. # "type": "ray.tune.logger.JsonLogger",
  91. # # Optional: Custom logdir (do not define this here
  92. # # for using ~/ray_results/...).
  93. # "logdir": "/tmp",
  94. # },
  95. # 3) Custom logger (see `MyPrintLogger` class above).
  96. "logger_config": {
  97. # Provide the class directly or via fully qualified class
  98. # path.
  99. "type": MyPrintLogger,
  100. # `config` keys:
  101. "prefix": "ABC",
  102. # Optional: Custom logdir (do not define this here
  103. # for using ~/ray_results/...).
  104. # "logdir": "/somewhere/on/my/file/system/"
  105. }
  106. }
  107. stop = {
  108. "training_iteration": args.stop_iters,
  109. "timesteps_total": args.stop_timesteps,
  110. "episode_reward_mean": args.stop_reward,
  111. }
  112. results = tune.run(
  113. args.run, config=config, stop=stop, verbose=2, loggers=[MyPrintLogger])
  114. if args.as_test:
  115. check_learning_achieved(results, args.stop_reward)
  116. ray.shutdown()