schedule.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from abc import ABCMeta, abstractmethod
  2. from ray.rllib.utils.annotations import DeveloperAPI
  3. from ray.rllib.utils.framework import try_import_tf
  4. tf1, tf, tfv = try_import_tf()
  5. @DeveloperAPI
  6. class Schedule(metaclass=ABCMeta):
  7. """Schedule classes implement various time-dependent scheduling schemas.
  8. - Constant behavior.
  9. - Linear decay.
  10. - Piecewise decay.
  11. - Exponential decay.
  12. Useful for backend-agnostic rate/weight changes for learning rates,
  13. exploration epsilons, beta parameters for prioritized replay, loss weights
  14. decay, etc..
  15. Each schedule can be called directly with the `t` (absolute time step)
  16. value and returns the value dependent on the Schedule and the passed time.
  17. """
  18. def __init__(self, framework):
  19. self.framework = framework
  20. def value(self, t):
  21. """Generates the value given a timestep (based on schedule's logic).
  22. Args:
  23. t (int): The time step. This could be a tf.Tensor.
  24. Returns:
  25. any: The calculated value depending on the schedule and `t`.
  26. """
  27. if self.framework in ["tf2", "tf", "tfe"]:
  28. return self._tf_value_op(t)
  29. return self._value(t)
  30. def __call__(self, t):
  31. """Simply calls self.value(t). Implemented to make Schedules callable.
  32. """
  33. return self.value(t)
  34. @DeveloperAPI
  35. @abstractmethod
  36. def _value(self, t):
  37. """
  38. Returns the value based on a time step input.
  39. Args:
  40. t (int): The time step. This could be a tf.Tensor.
  41. Returns:
  42. any: The calculated value depending on the schedule and `t`.
  43. """
  44. raise NotImplementedError
  45. @DeveloperAPI
  46. def _tf_value_op(self, t):
  47. """
  48. Returns the tf-op that calculates the value based on a time step input.
  49. Args:
  50. t (tf.Tensor): The time step op (int tf.Tensor).
  51. Returns:
  52. tf.Tensor: The calculated value depending on the schedule and `t`.
  53. """
  54. # By default (most of the time), tf should work with python code.
  55. # Override only if necessary.
  56. return self._value(t)