lr_scheduler.py

  1. """
  2. @Time : 2021-01-21 10:52:47
  3. @File : lr_scheduler.py
  4. @Author : Abtion
  5. @Email : abtion{at}outlook.com
  6. """
  7. import math
  8. import warnings
  9. from bisect import bisect_right
  10. from typing import List
  11. import torch
  12. from torch.optim.lr_scheduler import _LRScheduler
  13. __all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"]


class WarmupMultiStepLR(_LRScheduler):
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            milestones: List[int],
            gamma: float = 0.1,
            warmup_factor: float = 0.001,
            warmup_epochs: int = 2,
            warmup_method: str = "linear",
            last_epoch: int = -1,
            **kwargs,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
        )
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()
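

# --- Illustrative usage (not part of the original module) --------------------
# A minimal sketch of how WarmupMultiStepLR might be driven: linear warmup for
# two epochs, then step decay at the milestones. The helper name
# `_demo_warmup_multistep`, the toy parameter, and the hyper-parameters are
# assumptions added for illustration only.
def _demo_warmup_multistep(epochs: int = 10) -> List[List[float]]:
    """Return the per-epoch LR trajectory of WarmupMultiStepLR on a toy optimizer."""
    params = [torch.zeros(3, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler = WarmupMultiStepLR(
        optimizer, milestones=[6, 8], gamma=0.1,
        warmup_factor=0.001, warmup_epochs=2,
    )
    history = []
    for _ in range(epochs):
        history.append(scheduler.get_last_lr())  # LR used in this epoch
        optimizer.step()
        scheduler.step()
    return history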


class WarmupExponentialLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch,
    after a linear warmup phase. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        warmup_epochs (int): Number of warmup epochs. Default: 2.
        warmup_factor (float): Multiplicative factor applied to the base lr at
            the start of warmup. Default: 1/3.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """
    def __init__(self, optimizer, gamma, last_epoch=-1, warmup_epochs=2, warmup_factor=1.0 / 3, verbose=False,
                 **kwargs):
        self.gamma = gamma
        self.warmup_method = 'linear'
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
        )
        if self.last_epoch <= self.warmup_epochs:
            return [base_lr * warmup_factor
                    for base_lr in self.base_lrs]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]
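

# --- Illustrative usage (not part of the original module) --------------------
# A minimal sketch of WarmupExponentialLR: two epochs of linear warmup, then a
# fixed multiplicative decay by gamma each epoch. The helper name and the
# chosen hyper-parameters are assumptions added for illustration only.
def _demo_warmup_exponential(epochs: int = 6) -> List[List[float]]:
    """Return the per-epoch LR trajectory of WarmupExponentialLR on a toy optimizer."""
    params = [torch.zeros(3, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler = WarmupExponentialLR(optimizer, gamma=0.9, warmup_epochs=2)
    history = []
    for _ in range(epochs):
        history.append(scheduler.get_last_lr())
        optimizer.step()
        scheduler.step()
    return history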


class WarmupCosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iters (int): Maximum number of iterations.
        delay_iters (int): Number of iterations to hold the base lr after
            warmup, before cosine annealing starts. Default: 0.
        eta_min_lr (float): Minimum learning rate. Default: 0.
        warmup_factor (float): Warmup factor at the first iteration. Default: 0.001.
        warmup_epochs (int): Number of warmup epochs. Default: 2.
        warmup_method (str): Either "constant" or "linear". Default: "linear".
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            max_iters: int,
            delay_iters: int = 0,
            eta_min_lr: float = 0.0,
            warmup_factor: float = 0.001,
            warmup_epochs: int = 2,
            warmup_method: str = "linear",
            last_epoch=-1,
            **kwargs
    ):
        self.max_iters = max_iters
        self.delay_iters = delay_iters
        self.eta_min_lr = eta_min_lr
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        assert self.delay_iters >= self.warmup_epochs, "Scheduler delay iters must be no smaller than warmup epochs"
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
    def get_lr(self) -> List[float]:
        if self.last_epoch <= self.warmup_epochs:
            warmup_factor = _get_warmup_factor_at_iter(
                self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor,
            )
            return [
                base_lr * warmup_factor for base_lr in self.base_lrs
            ]
        elif self.last_epoch <= self.delay_iters:
            return self.base_lrs
        else:
            return [
                self.eta_min_lr + (base_lr - self.eta_min_lr) *
                (1 + math.cos(
                    math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2
                for base_lr in self.base_lrs]
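

# --- Illustrative usage (not part of the original module) --------------------
# A minimal sketch of WarmupCosineAnnealingLR: linear warmup for 2 epochs, a
# constant plateau up to `delay_iters`, then cosine decay toward `eta_min_lr`.
# The helper name and hyper-parameters are assumptions for illustration only.
def _demo_warmup_cosine(epochs: int = 10) -> List[List[float]]:
    """Return the per-epoch LR trajectory of WarmupCosineAnnealingLR on a toy optimizer."""
    params = [torch.zeros(3, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler = WarmupCosineAnnealingLR(
        optimizer, max_iters=epochs, delay_iters=4,
        warmup_factor=0.001, warmup_epochs=2,
    )
    history = []
    for _ in range(epochs):
        history.append(scheduler.get_last_lr())
        optimizer.step()
        scheduler.step()
    return history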


def _get_warmup_factor_at_iter(
        method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See https://arxiv.org/abs/1706.02677 for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration.
    """
    if iter >= warmup_iters:
        return 1.0
    if method == "constant":
        return warmup_factor
    elif method == "linear":
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    else:
        raise ValueError("Unknown warmup method: {}".format(method))
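

# --- Illustrative self-check (not part of the original module) ---------------
# A minimal sketch, assuming the module is run directly: it prints the linear
# warmup ramp and the LR trajectories produced by the demo helpers above.
if __name__ == "__main__":
    # Linear warmup ramps from warmup_factor (0.001) up to 1.0 over 5 iterations.
    print([round(_get_warmup_factor_at_iter("linear", i, 5, 0.001), 4) for i in range(6)])
    print("multi-step:", _demo_warmup_multistep())
    print("exponential:", _demo_warmup_exponential())
    print("cosine:", _demo_warmup_cosine())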