# lr_schedules.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Implementation of learning rate schedules.
  6. Taken and modified from PyTorch v1.0.1 source
  7. https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
  8. """
  9. import argparse
  10. from torch.optim import Optimizer
  11. import math
  12. from deepspeed.utils import logger
  13. LR_SCHEDULE = 'lr_schedule'
  14. LR_RANGE_TEST = 'LRRangeTest'
  15. ONE_CYCLE = 'OneCycle'
  16. WARMUP_LR = 'WarmupLR'
  17. WARMUP_DECAY_LR = 'WarmupDecayLR'
  18. WARMUP_COSINE_LR = 'WarmupCosineLR'
  19. VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR, WARMUP_COSINE_LR]
  20. LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
  21. LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
  22. LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
  23. LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'
  24. EDGE_VALUE = 'edge_value'
  25. MID_VALUE = 'mid_value'
  26. CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
  27. CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
  28. CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
  29. CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
  30. DECAY_STEP_SIZE = 'decay_step_size'
  31. CYCLE_MIN_LR = 'cycle_min_lr'
  32. CYCLE_MAX_LR = 'cycle_max_lr'
  33. DECAY_LR_RATE = 'decay_lr_rate'
  34. CYCLE_MIN_MOM = 'cycle_min_mom'
  35. CYCLE_MAX_MOM = 'cycle_max_mom'
  36. DECAY_MOM_RATE = 'decay_mom_rate'
  37. WARMUP_MIN_LR = 'warmup_min_lr'
  38. WARMUP_MAX_LR = 'warmup_max_lr'
  39. WARMUP_NUM_STEPS = 'warmup_num_steps'
  40. WARMUP_TYPE = 'warmup_type'
  41. WARMUP_LOG_RATE = 'log'
  42. WARMUP_LINEAR_RATE = 'linear'
  43. WARMUP_MIN_RATIO = 'warmup_min_ratio'
  44. COS_MIN_RATIO = 'cos_min_ratio'
  45. TOTAL_NUM_STEPS = 'total_num_steps'
  46. def add_tuning_arguments(parser):
  47. group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')
  48. # LR scheduler
  49. group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')
  50. # Learning rate range test
  51. group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.')
  52. group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.')
  53. group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.')
  54. group.add_argument("--lr_range_test_staircase",
  55. type=bool,
  56. default=False,
  57. help='use staircase scaling for LR range test.')
  58. # OneCycle schedule
  59. group.add_argument("--cycle_first_step_size",
  60. type=int,
  61. default=1000,
  62. help='size of first step of 1Cycle schedule (training steps).')
  63. group.add_argument("--cycle_first_stair_count",
  64. type=int,
  65. default=-1,
  66. help='first stair count for 1Cycle schedule.')
  67. group.add_argument("--cycle_second_step_size",
  68. type=int,
  69. default=-1,
  70. help='size of second step of 1Cycle schedule (default first_step_size).')
  71. group.add_argument("--cycle_second_stair_count",
  72. type=int,
  73. default=-1,
  74. help='second stair count for 1Cycle schedule.')
  75. group.add_argument("--decay_step_size",
  76. type=int,
  77. default=1000,
  78. help='size of intervals for applying post cycle decay (training steps).')
  79. # 1Cycle LR
  80. group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.')
  81. group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.')
  82. group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.')
  83. # 1Cycle Momentum
  84. group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
  85. group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.')
  86. group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.')
  87. group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.')
  88. # Warmup LR
  89. group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
  90. group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
  91. group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
  92. group.add_argument('--warmup_type',
  93. type=str,
  94. default=WARMUP_LOG_RATE,
  95. help='WarmupLR increasing function during warmup')
  96. # WarmUP cos LR
  97. group.add_argument("--warmup_min_ratio", type=float, default=0.01, help='Cosine LR lower bound.')
  98. group.add_argument("--cos_min_ratio", type=float, default=0.01, help='Cosine LR lower bound.')
  99. return parser
  100. def parse_arguments():
  101. parser = argparse.ArgumentParser()
  102. parser = add_tuning_arguments(parser)
  103. lr_sched_args, unknown_args = parser.parse_known_args()
  104. return lr_sched_args, unknown_args
  105. def override_lr_range_test_params(args, params):
  106. if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
  107. params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
  108. if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
  109. params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
  110. if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
  111. params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
  112. if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
  113. params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
  114. def override_1cycle_params(args, params):
  115. if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
  116. params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
  117. if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
  118. params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
  119. if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
  120. params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
  121. if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
  122. params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
  123. if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
  124. params[DECAY_STEP_SIZE] = args.decay_step_size
  125. # 1Cycle LR params
  126. if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
  127. params[CYCLE_MIN_LR] = args.cycle_min_lr
  128. if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
  129. params[CYCLE_MAX_LR] = args.cycle_max_lr
  130. if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
  131. params[DECAY_LR_RATE] = args.decay_lr_rate
  132. # 1Cycle MOM params
  133. if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
  134. params[CYCLE_MIN_MOM] = args.cycle_min_mom
  135. if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
  136. params[CYCLE_MAX_MOM] = args.cycle_max_mom
  137. if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
  138. params[DECAY_MOM_RATE] = args.decay_mom_rate
  139. def override_warmupLR_params(args, params):
  140. if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
  141. params[WARMUP_MIN_LR] = args.warmup_min_lr
  142. if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
  143. params[WARMUP_MAX_LR] = args.warmup_max_lr
  144. if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
  145. params[WARMUP_NUM_STEPS] = args.warmup_num_steps
  146. if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
  147. params[WARMUP_TYPE] = args.warmup_type
  148. def override_params(args, params):
  149. # LR range test params
  150. override_lr_range_test_params(args, params)
  151. # 1Cycle params
  152. override_1cycle_params(args, params)
  153. # WarmupLR params
  154. override_warmupLR_params(args, params)
  155. def get_config_from_args(args):
  156. if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
  157. return None, '--{} not specified on command line'.format(LR_SCHEDULE)
  158. if not args.lr_schedule in VALID_LR_SCHEDULES:
  159. return None, '{} is not supported LR schedule'.format(args.lr_schedule)
  160. config = {}
  161. config['type'] = args.lr_schedule
  162. config['params'] = {}
  163. if args.lr_schedule == LR_RANGE_TEST:
  164. override_lr_range_test_params(args, config['params'])
  165. elif args.lr_schedule == ONE_CYCLE:
  166. override_1cycle_params(args, config['params'])
  167. else:
  168. override_warmupLR_params(args, config['params'])
  169. return config, None
  170. def get_lr_from_config(config):
  171. if not 'type' in config:
  172. return None, 'LR schedule type not defined in config'
  173. if not 'params' in config:
  174. return None, 'LR schedule params not defined in config'
  175. lr_schedule = config['type']
  176. lr_params = config['params']
  177. if not lr_schedule in VALID_LR_SCHEDULES:
  178. return None, '{} is not a valid LR schedule'.format(lr_schedule)
  179. if lr_schedule == LR_RANGE_TEST:
  180. return lr_params[LR_RANGE_TEST_MIN_LR], ''
  181. if lr_schedule == ONE_CYCLE:
  182. return lr_params[CYCLE_MAX_LR], ''
  183. # Warmup LR
  184. return lr_params[WARMUP_MAX_LR], ''
  185. """
  186. Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
  187. optimizer to see if requirement is satisfied.
  188. TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
  189. """
  190. def get_torch_optimizer(optimizer):
  191. if isinstance(optimizer, Optimizer):
  192. return optimizer
  193. if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
  194. return optimizer.optimizer
  195. raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))
  196. class LRRangeTest(object):
  197. """Sets the learning rate of each parameter group according to
  198. learning rate range test (LRRT) policy. The policy increases learning
  199. rate starting from a base value with a constant frequency, as detailed in
  200. the paper `A disciplined approach to neural network hyper-parameters: Part1`_.
  201. LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to
  202. configure the LR boundaries for Cyclic LR schedules.
  203. LRRT changes the learning rate after every batch.
  204. `step` should be called after a batch has been used for training.
  205. Args:
  206. optimizer (Optimizer): Wrapped optimizer.
  207. lr_range_test_min_lr (float or list): Initial learning rate which is the
  208. lower boundary in the range test for each parameter group.
  209. lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
  210. lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
  211. lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
  212. last_batch_iteration (int): The index of the last batch. This parameter is used when
  213. resuming a training job. Since `step()` should be invoked after each
  214. batch instead of after each epoch, this number represents the total
  215. number of *batches* computed, not the total number of epochs computed.
  216. When last_batch_iteration=-1, the schedule is started from the beginning.
  217. Default: -1
  218. Example:
  219. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  220. >>> scheduler = LRRangeTest(optimizer)
  221. >>> data_loader = torch.utils.data.DataLoader(...)
  222. >>> for epoch in range(10):
  223. >>> for batch in data_loader:
  224. >>> train_batch(...)
  225. >>> scheduler.step()
  226. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
  227. https://arxiv.org/abs/1803.09820
  228. """
  229. def __init__(self,
  230. optimizer: Optimizer,
  231. lr_range_test_min_lr: float = 1e-3,
  232. lr_range_test_step_size: int = 2000,
  233. lr_range_test_step_rate: float = 1.0,
  234. lr_range_test_staircase: bool = False,
  235. last_batch_iteration: int = -1):
  236. self.optimizer = get_torch_optimizer(optimizer)
  237. if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
  238. if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
  239. raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups),
  240. len(lr_range_test_min_lr)))
  241. self.min_lr = list(lr_range_test_min_lr)
  242. else:
  243. self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
  244. self.step_size = lr_range_test_step_size
  245. self.step_rate = lr_range_test_step_rate
  246. self.last_batch_iteration = last_batch_iteration
  247. self.staircase = lr_range_test_staircase
  248. self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval
  249. if last_batch_iteration == -1:
  250. self._update_optimizer(self.min_lr)
  251. def _staircase_interval(self):
  252. return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
  253. def _continuous_interval(self):
  254. return float(self.last_batch_iteration + 1) / self.step_size
  255. def _get_increase(self):
  256. return (1 + self.step_rate * self.interval_fn())
  257. def get_lr(self):
  258. lr_increase = self._get_increase()
  259. return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]
  260. def get_last_lr(self):
  261. """ Return last computed learning rate by current scheduler.
  262. """
  263. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  264. return self._last_lr
  265. def _update_optimizer(self, group_lrs):
  266. for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
  267. param_group['lr'] = lr
  268. def step(self, batch_iteration=None):
  269. if batch_iteration is None:
  270. batch_iteration = self.last_batch_iteration + 1
  271. self.last_batch_iteration = batch_iteration
  272. self._update_optimizer(self.get_lr())
  273. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  274. def state_dict(self):
  275. return {'last_batch_iteration': self.last_batch_iteration}
  276. def load_state_dict(self, sd):
  277. self.last_batch_iteration = sd['last_batch_iteration']
  278. class OneCycle(object):
  279. """Sets the learning rate of each parameter group according to
  280. 1Cycle learning rate policy (1CLR). 1CLR is a variation of the
  281. Cyclical Learning Rate (CLR) policy that involves one cycle followed by
  282. decay. The policy simultaneously cycles the learning rate (and momentum)
  283. between two boundaries with a constant frequency, as detailed in
  284. the paper `A disciplined approach to neural network hyper-parameters`_.
  285. 1CLR policy changes the learning rate after every batch.
  286. `step` should be called after a batch has been used for training.
  287. This implementation was adapted from the github repo: `pytorch/pytorch`_
  288. Args:
  289. optimizer (Optimizer): Wrapped optimizer.
  290. cycle_min_lr (float or list): Initial learning rate which is the
  291. lower boundary in the cycle for each parameter group.
  292. cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
  293. for each parameter group. Functionally,
  294. it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
  295. The lr at any cycle is the sum of cycle_min_lr
  296. and some scaling of the amplitude; therefore
  297. cycle_max_lr may not actually be reached depending on
  298. scaling function.
  299. decay_lr_rate(float): Decay rate for learning rate. Default: 0.
  300. cycle_first_step_size (int): Number of training iterations in the
  301. increasing half of a cycle. Default: 2000
  302. cycle_second_step_size (int): Number of training iterations in the
  303. decreasing half of a cycle. If cycle_second_step_size is None,
  304. it is set to cycle_first_step_size. Default: None
  305. cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
  306. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  307. cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
  308. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  309. decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
  310. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  311. to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
  312. Default: True
  313. cycle_min_mom (float or list): Initial momentum which is the
  314. lower boundary in the cycle for each parameter group.
  315. Default: 0.8
  316. cycle_max_mom (float or list): Upper momentum boundaries in the cycle
  317. for each parameter group. Functionally,
  318. it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
  319. The momentum at any cycle is the difference of cycle_max_mom
  320. and some scaling of the amplitude; therefore
  321. cycle_min_mom may not actually be reached depending on
  322. scaling function. Default: 0.9
  323. decay_mom_rate (float): Decay rate for momentum. Default: 0.
  324. last_batch_iteration (int): The index of the last batch. This parameter is used when
  325. resuming a training job. Since `step()` should be invoked after each
  326. batch instead of after each epoch, this number represents the total
  327. number of *batches* computed, not the total number of epochs computed.
  328. When last_batch_iteration=-1, the schedule is started from the beginning.
  329. Default: -1
  330. Example:
  331. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  332. >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
  333. >>> data_loader = torch.utils.data.DataLoader(...)
  334. >>> for epoch in range(10):
  335. >>> for batch in data_loader:
  336. >>> train_batch(...)
  337. >>> scheduler.step()
  338. .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
  339. """
  340. def __init__(self,
  341. optimizer,
  342. cycle_min_lr,
  343. cycle_max_lr,
  344. decay_lr_rate=0.,
  345. cycle_first_step_size=2000,
  346. cycle_second_step_size=None,
  347. cycle_first_stair_count=0,
  348. cycle_second_stair_count=None,
  349. decay_step_size=0,
  350. cycle_momentum=True,
  351. cycle_min_mom=0.8,
  352. cycle_max_mom=0.9,
  353. decay_mom_rate=0.,
  354. last_batch_iteration=-1):
  355. self.optimizer = get_torch_optimizer(optimizer)
  356. # Initialize cycle shape
  357. self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
  358. cycle_second_stair_count, decay_step_size)
  359. # Initialize cycle lr
  360. self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)
  361. # Initialize cyclic momentum
  362. self.cycle_momentum = cycle_momentum
  363. if cycle_momentum:
  364. self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
  365. last_batch_iteration)
  366. # Initialize batch iteration tracker
  367. self.last_batch_iteration = last_batch_iteration
  368. # Configure cycle shape
  369. def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
  370. cycle_second_stair_count, decay_step_size):
  371. cycle_first_step_size = float(cycle_first_step_size)
  372. cycle_second_step_size = float(
  373. cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size
  374. self.total_size = cycle_first_step_size + cycle_second_step_size
  375. self.step_ratio = cycle_first_step_size / self.total_size
  376. self.first_stair_count = cycle_first_stair_count
  377. self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
  378. self.decay_step_size = decay_step_size
  379. if math.isclose(self.decay_step_size, 0):
  380. self.skip_lr_decay = True
  381. self.skip_mom_decay = True
  382. else:
  383. self.skip_lr_decay = False
  384. self.skip_mom_decay = False
  385. # Configure lr schedule
  386. def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
  387. self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
  388. if last_batch_iteration == -1:
  389. for lr, group in zip(self.min_lrs, optimizer.param_groups):
  390. group['lr'] = lr
  391. self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
  392. self.decay_lr_rate = decay_lr_rate
  393. if math.isclose(self.decay_lr_rate, 0):
  394. self.skip_lr_decay = True
  395. # Configure momentum schedule
  396. def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
  397. if 'betas' not in optimizer.defaults:
  398. optimizer_name = type(optimizer).__name__
  399. logger.warn(
  400. f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
  401. )
  402. self.cycle_momentum = False
  403. return
  404. self.decay_mom_rate = decay_mom_rate
  405. self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
  406. self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
  407. if last_batch_iteration == -1:
  408. for momentum, group in zip(self.min_moms, optimizer.param_groups):
  409. group['betas'] = momentum
  410. if math.isclose(self.decay_mom_rate, 0):
  411. self.skip_mom_decay = True
  412. def _get_scale_factor(self):
  413. batch_iteration = (self.last_batch_iteration + 1)
  414. cycle = math.floor(1 + batch_iteration / self.total_size)
  415. x = 1. + batch_iteration / self.total_size - cycle
  416. if x <= self.step_ratio:
  417. scale_factor = x / self.step_ratio
  418. else:
  419. scale_factor = (x - 1) / (self.step_ratio - 1)
  420. return scale_factor
  421. def _get_cycle_mom(self):
  422. scale_factor = self._get_scale_factor()
  423. momentums = []
  424. for base_betas, max_betas in zip(self.min_moms, self.max_moms):
  425. cycle_min_mom = base_betas[0]
  426. cycle_max_mom = max_betas[0]
  427. base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
  428. momentum = cycle_max_mom - base_height
  429. momentums.append((momentum, base_betas[1]))
  430. return momentums
  431. def _get_cycle_lr(self):
  432. scale_factor = self._get_scale_factor()
  433. lrs = []
  434. for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
  435. base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
  436. lr = cycle_min_lr + base_height
  437. lrs.append(lr)
  438. return lrs
  439. def _get_decay_mom(self, decay_batch_iteration):
  440. if self.skip_mom_decay:
  441. return self.max_moms
  442. decay_interval = decay_batch_iteration / self.decay_step_size
  443. mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
  444. momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
  445. return momentums
  446. def _get_decay_lr(self, decay_batch_iteration):
  447. """Calculates the learning rate at batch index. This function is used
  448. after the cycle completes and post cycle decaying of lr/mom is enabled.
  449. This function treats `self.last_batch_iteration` as the last batch index.
  450. """
  451. if self.skip_lr_decay:
  452. return self.min_lrs
  453. decay_interval = decay_batch_iteration / self.decay_step_size
  454. lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
  455. lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
  456. return lrs
  457. def get_lr(self):
  458. """Calculates the learning rate at batch index. This function treats
  459. `self.last_batch_iteration` as the last batch index.
  460. """
  461. if self.last_batch_iteration < self.total_size:
  462. return self._get_cycle_lr()
  463. return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
  464. def get_mom(self):
  465. """Calculates the momentum at batch index. This function treats
  466. `self.last_batch_iteration` as the last batch index.
  467. """
  468. if not self.cycle_momentum:
  469. return None
  470. if self.last_batch_iteration < self.total_size:
  471. return self._get_cycle_mom()
  472. return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
  473. def get_last_lr(self):
  474. """ Return last computed learning rate by current scheduler.
  475. """
  476. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  477. return self._last_lr
  478. def step(self, batch_iteration=None):
  479. """ Updates the optimizer with the learning rate for the last batch index.
  480. `self.last_batch_iteration` is treated as the last batch index.
  481. If self.cycle_momentum is true, also updates optimizer momentum.
  482. """
  483. if batch_iteration is None:
  484. batch_iteration = self.last_batch_iteration + 1
  485. self.last_batch_iteration = batch_iteration
  486. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  487. param_group['lr'] = lr
  488. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  489. if self.cycle_momentum:
  490. momentums = self.get_mom()
  491. for param_group, momentum in zip(self.optimizer.param_groups, momentums):
  492. param_group['betas'] = momentum
  493. def state_dict(self):
  494. return {'last_batch_iteration': self.last_batch_iteration}
  495. def load_state_dict(self, sd):
  496. self.last_batch_iteration = sd['last_batch_iteration']
  497. class WarmupLR(object):
  498. """Increase the learning rate of each parameter group from min lr to max lr
  499. over warmup_num_steps steps, and then fix at max lr.
  500. Args:
  501. optimizer (Optimizer): Wrapped optimizer.
  502. warmup_min_lr (float or list): minimum learning rate. Default: 0
  503. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  504. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  505. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  506. last_batch_iteration (int): The index of the last batch. Default: -1.
  507. Example:
  508. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  509. >>> scheduler = WarmupLR(optimizer)
  510. >>> data_loader = torch.utils.data.DataLoader(...)
  511. >>> for epoch in range(10):
  512. >>> for batch in data_loader:
  513. >>> train_batch(...)
  514. >>> scheduler.step()
  515. """
  516. def __init__(self,
  517. optimizer: Optimizer,
  518. warmup_min_lr: float = 0.0,
  519. warmup_max_lr: float = 0.001,
  520. warmup_num_steps: int = 1000,
  521. warmup_type: str = WARMUP_LOG_RATE,
  522. last_batch_iteration: int = -1):
  523. self.optimizer = get_torch_optimizer(optimizer)
  524. self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
  525. self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
  526. self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
  527. self.warmup_num_steps = max(2, warmup_num_steps)
  528. # Currently only support linear and log function
  529. if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
  530. logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
  531. f"is set to default (log)")
  532. warmup_type = WARMUP_LOG_RATE
  533. self.warmup_type = warmup_type
  534. self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
  535. self.last_batch_iteration = last_batch_iteration
  536. def get_lr(self):
  537. if self.last_batch_iteration < 0:
  538. logger.warning("Attempting to get learning rate from scheduler before it has started")
  539. return [0.0]
  540. gamma = self._get_gamma()
  541. return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
  542. def get_last_lr(self):
  543. """ Return last computed learning rate by current scheduler.
  544. """
  545. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  546. return self._last_lr
  547. def step(self, last_batch_iteration=None):
  548. if last_batch_iteration is None:
  549. last_batch_iteration = self.last_batch_iteration + 1
  550. self.last_batch_iteration = last_batch_iteration
  551. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  552. param_group['lr'] = lr
  553. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  554. def state_dict(self):
  555. return {'last_batch_iteration': self.last_batch_iteration}
  556. def load_state_dict(self, sd):
  557. self.last_batch_iteration = sd['last_batch_iteration']
  558. def _get_gamma(self):
  559. if self.last_batch_iteration < self.warmup_num_steps:
  560. if self.warmup_type == WARMUP_LOG_RATE:
  561. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  562. elif self.warmup_type == WARMUP_LINEAR_RATE:
  563. return self.last_batch_iteration / self.warmup_num_steps
  564. return 1.0
  565. def _format_param(self, optimizer, param_value, param_name):
  566. if isinstance(param_value, list) or isinstance(param_value, tuple):
  567. if len(param_value) != len(optimizer.param_groups):
  568. raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
  569. FileNotFoundError(param_value)))
  570. return list(param_value)
  571. return [param_value] * len(optimizer.param_groups)
  572. class WarmupDecayLR(WarmupLR):
  573. """Increase the learning rate of each parameter group from min lr to max lr
  574. over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
  575. Args:
  576. optimizer (Optimizer): Wrapped optimizer.
  577. total_num_steps (int): total number of training steps
  578. warmup_min_lr (float or list): minimum learning rate. Default: 0
  579. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  580. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  581. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  582. last_batch_iteration (int): The index of the last batch. Default: -1.
  583. Example:
  584. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  585. >>> scheduler = WarmupDecayLR(optimizer, 1000000)
  586. >>> data_loader = torch.utils.data.DataLoader(...)
  587. >>> for epoch in range(10):
  588. >>> for batch in data_loader:
  589. >>> train_batch(...)
  590. >>> scheduler.step()
  591. """
  592. def __init__(self,
  593. optimizer: Optimizer,
  594. total_num_steps: int,
  595. warmup_min_lr: float = 0.0,
  596. warmup_max_lr: float = 0.001,
  597. warmup_num_steps: int = 1000,
  598. warmup_type: str = WARMUP_LOG_RATE,
  599. last_batch_iteration: int = -1):
  600. self.total_num_steps = total_num_steps
  601. super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
  602. last_batch_iteration)
  603. if self.total_num_steps < self.warmup_num_steps:
  604. logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
  605. total_num_steps, warmup_num_steps))
  606. def _get_gamma(self):
  607. if self.last_batch_iteration < self.warmup_num_steps:
  608. if self.warmup_type == WARMUP_LOG_RATE:
  609. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  610. elif self.warmup_type == WARMUP_LINEAR_RATE:
  611. return self.last_batch_iteration / self.warmup_num_steps
  612. return max(
  613. 0.0,
  614. float(self.total_num_steps - self.last_batch_iteration) /
  615. float(max(1.0, self.total_num_steps - self.warmup_num_steps)))
  616. class WarmupCosineLR(object):
  617. """Increase the learning rate of each parameter group from min lr ratio to max lr ratio
  618. over warmup_num_steps steps, and then decay at cosine rate over the remaining training steps to min cosine ratio.
  619. Args:
  620. optimizer (Optimizer): Wrapped optimizer.
  621. total_num_steps (int): total number of training steps
  622. warmup_min_ratio (float or list): warmup start learning rate ratio. Default: 0
  623. warmup_num_steps (int): number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000
  624. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  625. cos_min_ratio (float): cosine end learning rate ratio. Default: 0.0001
  626. last_batch_iteration (int): The index of the last batch. Default: -1.
  627. Example:
  628. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  629. >>> scheduler = WarmupCosineLR(optimizer, 1000000)
  630. >>> data_loader = torch.utils.data.DataLoader(...)
  631. >>> for epoch in range(10):
  632. >>> for batch in data_loader:
  633. >>> train_batch(...)
  634. >>> scheduler.step()
  635. """
  636. def __init__(self,
  637. optimizer: Optimizer,
  638. total_num_steps: int,
  639. warmup_min_ratio: float = 0.0,
  640. warmup_num_steps: int = 1000,
  641. cos_min_ratio: float = 0.0001,
  642. warmup_type: str = WARMUP_LOG_RATE,
  643. last_batch_iteration: int = -1):
  644. self.optimizer = get_torch_optimizer(optimizer)
  645. self.total_num_steps = total_num_steps
  646. self.last_batch_iteration = last_batch_iteration
  647. self.cos_min_ratio = cos_min_ratio
  648. self.warmup_type = warmup_type
  649. self.warmup_min_ratio = warmup_min_ratio
  650. self.warmup_num_steps = max(2, warmup_num_steps)
  651. self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
  652. if self.total_num_steps < self.warmup_num_steps:
  653. logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
  654. total_num_steps, warmup_num_steps))
  655. self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
  656. def get_lr_ratio(self):
  657. if self.last_batch_iteration < 0:
  658. logger.warning("Attempting to get learning rate from scheduler before it has started")
  659. return [0.0]
  660. if self.last_batch_iteration < self.warmup_num_steps:
  661. if self.warmup_type == WARMUP_LOG_RATE:
  662. ratio = self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  663. elif self.warmup_type == WARMUP_LINEAR_RATE:
  664. ratio = self.last_batch_iteration / self.warmup_num_steps
  665. ratio_delta = 1. - self.warmup_min_ratio
  666. ratio = self.warmup_min_ratio + ratio * ratio_delta
  667. return ratio
  668. real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1
  669. real_total_steps = self.total_num_steps - self.warmup_num_steps
  670. ratio_delta = 1. - self.cos_min_ratio
  671. ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
  672. ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
  673. return ratio
  674. def step(self, last_batch_iteration=None):
  675. if last_batch_iteration is None:
  676. last_batch_iteration = self.last_batch_iteration + 1
  677. self.last_batch_iteration = last_batch_iteration
  678. lrs = self.get_lr()
  679. for param_group, lr in zip(self.optimizer.param_groups, lrs):
  680. param_group['lr'] = lr
  681. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  682. def get_lr(self):
  683. if self.last_batch_iteration < 0:
  684. logger.warning("Attempting to get learning rate from scheduler before it has started")
  685. return [0.0]
  686. lr_ratio = self.get_lr_ratio()
  687. return [org_lr * lr_ratio for org_lr in self.org_lrs]
  688. def get_last_lr(self):
  689. """ Return last computed learning rate by current scheduler.
  690. """
  691. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  692. return self._last_lr
  693. def state_dict(self):
  694. return {'last_batch_iteration': self.last_batch_iteration}
  695. def load_state_dict(self, sd):
  696. self.last_batch_iteration = sd['last_batch_iteration']
  697. def _format_param(self, optimizer, param_value, param_name):
  698. if isinstance(param_value, list) or isinstance(param_value, tuple):
  699. if len(param_value) != len(optimizer.param_groups):
  700. raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
  701. FileNotFoundError(param_value)))
  702. return list(param_value)
  703. return [param_value] * len(optimizer.param_groups)