recsim_with_slateq.py

  1. """The SlateQ algorithm for recommendation"""
  2. import argparse
  3. from datetime import datetime
  4. import ray
  5. from ray import tune
  6. from ray.rllib.agents import slateq
  7. from ray.rllib.agents import dqn
  8. from ray.rllib.agents.slateq.slateq import ALL_SLATEQ_STRATEGIES
  9. from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name
  10. from ray.tune.logger import pretty_print
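# NOTE (assumption): `recsim_env_name` refers to the environment id that
# RLlib's RecSim wrapper registers for Google's RecSim interest-evolution
# simulator; the `recsim` package must be installed for it to run.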
parser = argparse.ArgumentParser()
parser.add_argument(
    "--agent",
    type=str,
    default="SlateQ",
    help=("Select agent policy. Choose from: DQN and SlateQ. "
          "Default value: SlateQ."),
)
parser.add_argument(
    "--strategy",
    type=str,
    default="QL",
    help=("Strategy for the SlateQ agent. Choose from: " +
          ", ".join(ALL_SLATEQ_STRATEGIES) + ". "
          "Default value: QL. Ignored when using Tune."),
)
parser.add_argument(
    "--use-tune",
    action="store_true",
    help=("Run with Tune so that the results are logged into Tensorboard. "
          "For debugging, it's easier to run without Ray Tune."),
)
parser.add_argument("--tune-num-samples", type=int, default=10)
parser.add_argument("--env-slate-size", type=int, default=2)
parser.add_argument("--env-seed", type=int, default=0)
parser.add_argument(
    "--num-gpus",
    type=float,
    default=0.,
    help="Only used if running with Tune.")
parser.add_argument(
    "--num-workers",
    type=int,
    default=0,
    help="Only used if running with Tune.")
def main():
    args = parser.parse_args()
    ray.init()
    if args.agent not in ["DQN", "SlateQ"]:
        raise ValueError(args.agent)
    env_config = {
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.agent == "DQN",
    }
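    # DQN has no native notion of slates, so for the DQN baseline the RecSim
    # wrapper is asked to flatten the combinatorial slate action space into a
    # single Discrete space (one action per candidate slate).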
    if args.use_tune:
        time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}"
        if args.agent == "DQN":
            tune.run(
                "DQN",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1)
        else:
            tune.run(
                "SlateQ",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "slateq_strategy": tune.grid_search(ALL_SLATEQ_STRATEGIES),
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1)
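        # Note: Tune repeats the whole grid for each sample, so the SlateQ
        # sweep above launches --tune-num-samples trials per strategy.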
    else:
        # Directly run using the trainer interface (good for debugging).
        if args.agent == "DQN":
            config = dqn.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["env_config"] = env_config
            trainer = dqn.DQNTrainer(config=config, env=recsim_env_name)
        else:
            config = slateq.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["slateq_strategy"] = args.strategy
            config["env_config"] = env_config
            trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
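        # Train for a fixed number of iterations; each train() call runs one
        # round of sample collection plus optimization and returns a result
        # dict (episode_reward_mean, timesteps_total, etc.).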
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()
if __name__ == "__main__":
    main()
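# Example invocations (assuming `ray[rllib]` and `recsim` are installed):
#   python recsim_with_slateq.py                            # SlateQ, "QL"
#   python recsim_with_slateq.py --agent=DQN                # DQN baseline
#   python recsim_with_slateq.py --use-tune --num-workers=2 --num-gpus=1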