exponential_schedule.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules.schedule import Schedule

# `try_import_torch` returns (None, None) when PyTorch is not installed,
# so `torch` must be truthiness-checked before use.
torch, _ = try_import_torch()
  5. class ExponentialSchedule(Schedule):
  6. def __init__(self,
  7. schedule_timesteps,
  8. framework,
  9. initial_p=1.0,
  10. decay_rate=0.1):
  11. """
  12. Exponential decay schedule from initial_p to final_p over
  13. schedule_timesteps. After this many time steps always `final_p` is
  14. returned.
  15. Agrs:
  16. schedule_timesteps (int): Number of time steps for which to
  17. linearly anneal initial_p to final_p
  18. initial_p (float): Initial output value.
  19. decay_rate (float): The percentage of the original value after
  20. 100% of the time has been reached (see formula above).
  21. >0.0: The smaller the decay-rate, the stronger the decay.
  22. 1.0: No decay at all.
  23. framework (Optional[str]): One of "tf", "torch", or None.
  24. """
  25. super().__init__(framework=framework)
  26. assert schedule_timesteps > 0
  27. self.schedule_timesteps = schedule_timesteps
  28. self.initial_p = initial_p
  29. self.decay_rate = decay_rate
  30. @override(Schedule)
  31. def _value(self, t):
  32. """Returns the result of: initial_p * decay_rate ** (`t`/t_max)
  33. """
  34. if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
  35. t = t.float()
  36. return self.initial_p * \
  37. self.decay_rate ** (t / self.schedule_timesteps)