mock.py

import os
import pickle

import numpy as np

from ray.tune import result as tune_result
from ray.rllib.agents.trainer import Trainer, with_common_config


class _MockTrainer(Trainer):
    """Mock trainer for use in tests."""

    _name = "MockTrainer"

    _default_config = with_common_config({
        "mock_error": False,
        "persistent_error": False,
        "test_variable": 1,
        "num_workers": 0,
        "user_checkpoint_freq": 0,
        "framework": "tf",
    })

    @classmethod
    def default_resource_request(cls, config):
        return None

    def _init(self, config, env_creator):
        self.info = None
        self.restored = False

    def step(self):
        # Optionally fail on the second iteration to simulate a crashing
        # trial; with "persistent_error" the failure repeats even after the
        # trainer has been restored from a checkpoint.
        if self.config["mock_error"] and self.iteration == 1 \
                and (self.config["persistent_error"] or not self.restored):
            raise Exception("mock error")
        result = dict(
            episode_reward_mean=10,
            episode_len_mean=10,
            timesteps_this_iter=10,
            info={})
        # Request a checkpoint from Tune every "user_checkpoint_freq" iters.
        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
            if self.iteration % self.config["user_checkpoint_freq"] == 0:
                result.update({tune_result.SHOULD_CHECKPOINT: True})
        return result

    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.info, f)
        return path

    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    def _register_if_needed(self, env_object, config):
        pass

    def set_info(self, info):
        self.info = info
        return info

    def get_info(self, sess=None):
        return self.info
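
# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch, assuming the Ray 1.x Tune API (`tune.run`, its
# `checkpoint_freq` and `max_failures` arguments): _MockTrainer behaves like
# any Tune trainable, and "mock_error" injects a one-off failure so trial
# fault tolerance can be exercised end to end.
def _example_run_mock_trainer():
    from ray import tune

    # With checkpoint_freq=1 a checkpoint exists before the injected failure,
    # so Tune restores the trial (setting restored=True) and no further error
    # is raised; the trial then runs to the stopping condition.
    return tune.run(
        _MockTrainer,
        config={"mock_error": True},
        stop={"training_iteration": 5},
        checkpoint_freq=1,
        max_failures=2)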


class _SigmoidFakeData(_MockTrainer):
    """Trainer that returns sigmoid learning curves.

    This can be helpful for evaluating early stopping algorithms.
    """

    _name = "SigmoidFakeData"

    _default_config = with_common_config({
        "width": 100,
        "height": 100,
        "offset": 0,
        "iter_time": 10,
        "iter_timesteps": 1,
        "num_workers": 0,
    })

    def step(self):
        i = max(0, self.iteration - self.config["offset"])
        v = np.tanh(float(i) / self.config["width"])
        v *= self.config["height"]
        return dict(
            episode_reward_mean=v,
            episode_len_mean=v,
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={})
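
# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch, assuming Ray 1.x Tune APIs (`tune.run`, `ASHAScheduler`):
# the fake reward follows a smooth saturating curve whose rise time is set by
# "width", which makes this trainer convenient for checking that an
# early-stopping scheduler prunes the slow-rising trials.
def _example_run_sigmoid_fake_data():
    from ray import tune
    from ray.tune.schedulers import ASHAScheduler

    return tune.run(
        _SigmoidFakeData,
        config={"width": tune.grid_search([10, 100, 1000]), "height": 100},
        stop={"training_iteration": 50},
        metric="episode_reward_mean",
        mode="max",
        scheduler=ASHAScheduler(grace_period=5))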


class _ParameterTuningTrainer(_MockTrainer):
    _name = "ParameterTuningTrainer"

    _default_config = with_common_config({
        "reward_amt": 10,
        "dummy_param": 10,
        "dummy_param2": 15,
        "iter_time": 10,
        "iter_timesteps": 1,
        "num_workers": 0,
    })

    def step(self):
        return dict(
            episode_reward_mean=self.config["reward_amt"] * self.iteration,
            episode_len_mean=self.config["reward_amt"],
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={})


def _trainer_import_failed(trace):
    """Returns a dummy Trainer class to use when e.g. PyTorch is not installed."""

    class _TrainerImportFailed(Trainer):
        _name = "TrainerImportFailed"

        _default_config = with_common_config({})

        def setup(self, config):
            raise ImportError(trace)

    return _TrainerImportFailed
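
# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of how the factory above can be applied: when an optional
# dependency is missing, hand out the returned placeholder class instead of
# the real trainer, so the ImportError (carrying the captured traceback) only
# surfaces once someone actually tries to build that trainer. The use of
# _MockTrainer as the "real" class here is purely for illustration.
def _example_trainer_import_failed():
    import traceback
    try:
        import torch  # noqa: F401  # or any other optional dependency
        trainer_cls = _MockTrainer
    except ImportError:
        trainer_cls = _trainer_import_failed(traceback.format_exc())
    return trainer_cls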