# lr_schedules.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Implementation of learning rate schedules.
  6. Taken and modified from PyTorch v1.0.1 source
  7. https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
  8. """
  9. import argparse
  10. from torch.optim import Optimizer
  11. import math
  12. from deepspeed.utils import logger
# Config key naming the selected schedule, and the supported schedule names.
LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
ONE_CYCLE = 'OneCycle'
WARMUP_LR = 'WarmupLR'
WARMUP_DECAY_LR = 'WarmupDecayLR'
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]

# Parameter keys for the LRRangeTest schedule. Note: each value matches the
# corresponding argparse attribute name, so the same string indexes both
# `args` and the params dict.
LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'

# Parameter keys for the OneCycle schedule (cycle shape, LR and momentum bounds).
EDGE_VALUE = 'edge_value'
MID_VALUE = 'mid_value'
CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
DECAY_STEP_SIZE = 'decay_step_size'
CYCLE_MIN_LR = 'cycle_min_lr'
CYCLE_MAX_LR = 'cycle_max_lr'
DECAY_LR_RATE = 'decay_lr_rate'
CYCLE_MIN_MOM = 'cycle_min_mom'
CYCLE_MAX_MOM = 'cycle_max_mom'
DECAY_MOM_RATE = 'decay_mom_rate'

# Parameter keys for the WarmupLR / WarmupDecayLR schedules.
WARMUP_MIN_LR = 'warmup_min_lr'
WARMUP_MAX_LR = 'warmup_max_lr'
WARMUP_NUM_STEPS = 'warmup_num_steps'
WARMUP_TYPE = 'warmup_type'
WARMUP_LOG_RATE = 'log'
WARMUP_LINEAR_RATE = 'linear'
TOTAL_NUM_STEPS = 'total_num_steps'
  43. def add_tuning_arguments(parser):
  44. group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')
  45. # LR scheduler
  46. group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')
  47. # Learning rate range test
  48. group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.')
  49. group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.')
  50. group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.')
  51. group.add_argument("--lr_range_test_staircase",
  52. type=bool,
  53. default=False,
  54. help='use staircase scaling for LR range test.')
  55. # OneCycle schedule
  56. group.add_argument("--cycle_first_step_size",
  57. type=int,
  58. default=1000,
  59. help='size of first step of 1Cycle schedule (training steps).')
  60. group.add_argument("--cycle_first_stair_count",
  61. type=int,
  62. default=-1,
  63. help='first stair count for 1Cycle schedule.')
  64. group.add_argument("--cycle_second_step_size",
  65. type=int,
  66. default=-1,
  67. help='size of second step of 1Cycle schedule (default first_step_size).')
  68. group.add_argument("--cycle_second_stair_count",
  69. type=int,
  70. default=-1,
  71. help='second stair count for 1Cycle schedule.')
  72. group.add_argument("--decay_step_size",
  73. type=int,
  74. default=1000,
  75. help='size of intervals for applying post cycle decay (training steps).')
  76. # 1Cycle LR
  77. group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.')
  78. group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.')
  79. group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.')
  80. # 1Cycle Momentum
  81. group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
  82. group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.')
  83. group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.')
  84. group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.')
  85. # Warmup LR
  86. group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
  87. group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
  88. group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
  89. group.add_argument('--warmup_type',
  90. type=str,
  91. default=WARMUP_LOG_RATE,
  92. help='WarmupLR increasing function during warmup')
  93. return parser
  94. def parse_arguments():
  95. parser = argparse.ArgumentParser()
  96. parser = add_tuning_arguments(parser)
  97. lr_sched_args, unknown_args = parser.parse_known_args()
  98. return lr_sched_args, unknown_args
  99. def override_lr_range_test_params(args, params):
  100. if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
  101. params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
  102. if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
  103. params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
  104. if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
  105. params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
  106. if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
  107. params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
  108. def override_1cycle_params(args, params):
  109. if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
  110. params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
  111. if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
  112. params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
  113. if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
  114. params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
  115. if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
  116. params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
  117. if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
  118. params[DECAY_STEP_SIZE] = args.decay_step_size
  119. # 1Cycle LR params
  120. if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
  121. params[CYCLE_MIN_LR] = args.cycle_min_lr
  122. if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
  123. params[CYCLE_MAX_LR] = args.cycle_max_lr
  124. if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
  125. params[DECAY_LR_RATE] = args.decay_lr_rate
  126. # 1Cycle MOM params
  127. if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
  128. params[CYCLE_MIN_MOM] = args.cycle_min_mom
  129. if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
  130. params[CYCLE_MAX_MOM] = args.cycle_max_mom
  131. if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
  132. params[DECAY_MOM_RATE] = args.decay_mom_rate
  133. def override_warmupLR_params(args, params):
  134. if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
  135. params[WARMUP_MIN_LR] = args.warmup_min_lr
  136. if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
  137. params[WARMUP_MAX_LR] = args.warmup_max_lr
  138. if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
  139. params[WARMUP_NUM_STEPS] = args.warmup_num_steps
  140. if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
  141. params[WARMUP_TYPE] = args.warmup_type
  142. def override_params(args, params):
  143. # LR range test params
  144. override_lr_range_test_params(args, params)
  145. # 1Cycle params
  146. override_1cycle_params(args, params)
  147. # WarmupLR params
  148. override_warmupLR_params(args, params)
  149. def get_config_from_args(args):
  150. if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
  151. return None, '--{} not specified on command line'.format(LR_SCHEDULE)
  152. if not args.lr_schedule in VALID_LR_SCHEDULES:
  153. return None, '{} is not supported LR schedule'.format(args.lr_schedule)
  154. config = {}
  155. config['type'] = args.lr_schedule
  156. config['params'] = {}
  157. if args.lr_schedule == LR_RANGE_TEST:
  158. override_lr_range_test_params(args, config['params'])
  159. elif args.lr_schedule == ONE_CYCLE:
  160. override_1cycle_params(args, config['params'])
  161. else:
  162. override_warmupLR_params(args, config['params'])
  163. return config, None
  164. def get_lr_from_config(config):
  165. if not 'type' in config:
  166. return None, 'LR schedule type not defined in config'
  167. if not 'params' in config:
  168. return None, 'LR schedule params not defined in config'
  169. lr_schedule = config['type']
  170. lr_params = config['params']
  171. if not lr_schedule in VALID_LR_SCHEDULES:
  172. return None, '{} is not a valid LR schedule'.format(lr_schedule)
  173. if lr_schedule == LR_RANGE_TEST:
  174. return lr_params[LR_RANGE_TEST_MIN_LR], ''
  175. if lr_schedule == ONE_CYCLE:
  176. return lr_params[CYCLE_MAX_LR], ''
  177. # Warmup LR
  178. return lr_params[WARMUP_MAX_LR], ''
  179. """
  180. Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
  181. optimizer to see if requirement is satisfied.
  182. TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
  183. """
  184. def get_torch_optimizer(optimizer):
  185. if isinstance(optimizer, Optimizer):
  186. return optimizer
  187. if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
  188. return optimizer.optimizer
  189. raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))
  190. class LRRangeTest(object):
  191. """Sets the learning rate of each parameter group according to
  192. learning rate range test (LRRT) policy. The policy increases learning
  193. rate starting from a base value with a constant frequency, as detailed in
  194. the paper `A disciplined approach to neural network hyper-parameters: Part1`_.
  195. LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to
  196. configure the LR boundaries for Cyclic LR schedules.
  197. LRRT changes the learning rate after every batch.
  198. `step` should be called after a batch has been used for training.
  199. Args:
  200. optimizer (Optimizer): Wrapped optimizer.
  201. lr_range_test_min_lr (float or list): Initial learning rate which is the
  202. lower boundary in the range test for each parameter group.
  203. lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
  204. lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
  205. lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
  206. last_batch_iteration (int): The index of the last batch. This parameter is used when
  207. resuming a training job. Since `step()` should be invoked after each
  208. batch instead of after each epoch, this number represents the total
  209. number of *batches* computed, not the total number of epochs computed.
  210. When last_batch_iteration=-1, the schedule is started from the beginning.
  211. Default: -1
  212. Example:
  213. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  214. >>> scheduler = LRRangeTest(optimizer)
  215. >>> data_loader = torch.utils.data.DataLoader(...)
  216. >>> for epoch in range(10):
  217. >>> for batch in data_loader:
  218. >>> train_batch(...)
  219. >>> scheduler.step()
  220. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
  221. https://arxiv.org/abs/1803.09820
  222. """
  223. def __init__(self,
  224. optimizer: Optimizer,
  225. lr_range_test_min_lr: float = 1e-3,
  226. lr_range_test_step_size: int = 2000,
  227. lr_range_test_step_rate: float = 1.0,
  228. lr_range_test_staircase: bool = False,
  229. last_batch_iteration: int = -1):
  230. self.optimizer = get_torch_optimizer(optimizer)
  231. if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
  232. if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
  233. raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups),
  234. len(lr_range_test_min_lr)))
  235. self.min_lr = list(lr_range_test_min_lr)
  236. else:
  237. self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
  238. self.step_size = lr_range_test_step_size
  239. self.step_rate = lr_range_test_step_rate
  240. self.last_batch_iteration = last_batch_iteration
  241. self.staircase = lr_range_test_staircase
  242. self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval
  243. if last_batch_iteration == -1:
  244. self._update_optimizer(self.min_lr)
  245. def _staircase_interval(self):
  246. return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
  247. def _continuous_interval(self):
  248. return float(self.last_batch_iteration + 1) / self.step_size
  249. def _get_increase(self):
  250. return (1 + self.step_rate * self.interval_fn())
  251. def get_lr(self):
  252. lr_increase = self._get_increase()
  253. return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]
  254. def get_last_lr(self):
  255. """ Return last computed learning rate by current scheduler.
  256. """
  257. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  258. return self._last_lr
  259. def _update_optimizer(self, group_lrs):
  260. for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
  261. param_group['lr'] = lr
  262. def step(self, batch_iteration=None):
  263. if batch_iteration is None:
  264. batch_iteration = self.last_batch_iteration + 1
  265. self.last_batch_iteration = batch_iteration
  266. self._update_optimizer(self.get_lr())
  267. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  268. def state_dict(self):
  269. return {'last_batch_iteration': self.last_batch_iteration}
  270. def load_state_dict(self, sd):
  271. self.last_batch_iteration = sd['last_batch_iteration']
  272. class OneCycle(object):
  273. """Sets the learning rate of each parameter group according to
  274. 1Cycle learning rate policy (1CLR). 1CLR is a variation of the
  275. Cyclical Learning Rate (CLR) policy that involves one cycle followed by
  276. decay. The policy simultaneously cycles the learning rate (and momentum)
  277. between two boundaries with a constant frequency, as detailed in
  278. the paper `A disciplined approach to neural network hyper-parameters`_.
  279. 1CLR policy changes the learning rate after every batch.
  280. `step` should be called after a batch has been used for training.
  281. This implementation was adapted from the github repo: `pytorch/pytorch`_
  282. Args:
  283. optimizer (Optimizer): Wrapped optimizer.
  284. cycle_min_lr (float or list): Initial learning rate which is the
  285. lower boundary in the cycle for each parameter group.
  286. cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
  287. for each parameter group. Functionally,
  288. it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
  289. The lr at any cycle is the sum of cycle_min_lr
  290. and some scaling of the amplitude; therefore
  291. cycle_max_lr may not actually be reached depending on
  292. scaling function.
  293. decay_lr_rate(float): Decay rate for learning rate. Default: 0.
  294. cycle_first_step_size (int): Number of training iterations in the
  295. increasing half of a cycle. Default: 2000
  296. cycle_second_step_size (int): Number of training iterations in the
  297. decreasing half of a cycle. If cycle_second_step_size is None,
  298. it is set to cycle_first_step_size. Default: None
  299. cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
  300. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  301. cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
  302. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  303. decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
  304. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  305. to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
  306. Default: True
  307. cycle_min_mom (float or list): Initial momentum which is the
  308. lower boundary in the cycle for each parameter group.
  309. Default: 0.8
  310. cycle_max_mom (float or list): Upper momentum boundaries in the cycle
  311. for each parameter group. Functionally,
  312. it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
  313. The momentum at any cycle is the difference of cycle_max_mom
  314. and some scaling of the amplitude; therefore
  315. cycle_min_mom may not actually be reached depending on
  316. scaling function. Default: 0.9
  317. decay_mom_rate (float): Decay rate for momentum. Default: 0.
  318. last_batch_iteration (int): The index of the last batch. This parameter is used when
  319. resuming a training job. Since `step()` should be invoked after each
  320. batch instead of after each epoch, this number represents the total
  321. number of *batches* computed, not the total number of epochs computed.
  322. When last_batch_iteration=-1, the schedule is started from the beginning.
  323. Default: -1
  324. Example:
  325. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  326. >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
  327. >>> data_loader = torch.utils.data.DataLoader(...)
  328. >>> for epoch in range(10):
  329. >>> for batch in data_loader:
  330. >>> train_batch(...)
  331. >>> scheduler.step()
  332. .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
  333. """
  334. def __init__(self,
  335. optimizer,
  336. cycle_min_lr,
  337. cycle_max_lr,
  338. decay_lr_rate=0.,
  339. cycle_first_step_size=2000,
  340. cycle_second_step_size=None,
  341. cycle_first_stair_count=0,
  342. cycle_second_stair_count=None,
  343. decay_step_size=0,
  344. cycle_momentum=True,
  345. cycle_min_mom=0.8,
  346. cycle_max_mom=0.9,
  347. decay_mom_rate=0.,
  348. last_batch_iteration=-1):
  349. self.optimizer = get_torch_optimizer(optimizer)
  350. # Initialize cycle shape
  351. self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
  352. cycle_second_stair_count, decay_step_size)
  353. # Initialize cycle lr
  354. self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)
  355. # Initialize cyclic momentum
  356. self.cycle_momentum = cycle_momentum
  357. if cycle_momentum:
  358. self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
  359. last_batch_iteration)
  360. # Initialize batch iteration tracker
  361. self.last_batch_iteration = last_batch_iteration
  362. # Configure cycle shape
  363. def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
  364. cycle_second_stair_count, decay_step_size):
  365. cycle_first_step_size = float(cycle_first_step_size)
  366. cycle_second_step_size = float(
  367. cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size
  368. self.total_size = cycle_first_step_size + cycle_second_step_size
  369. self.step_ratio = cycle_first_step_size / self.total_size
  370. self.first_stair_count = cycle_first_stair_count
  371. self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
  372. self.decay_step_size = decay_step_size
  373. if math.isclose(self.decay_step_size, 0):
  374. self.skip_lr_decay = True
  375. self.skip_mom_decay = True
  376. else:
  377. self.skip_lr_decay = False
  378. self.skip_mom_decay = False
  379. # Configure lr schedule
  380. def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
  381. self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
  382. if last_batch_iteration == -1:
  383. for lr, group in zip(self.min_lrs, optimizer.param_groups):
  384. group['lr'] = lr
  385. self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
  386. self.decay_lr_rate = decay_lr_rate
  387. if math.isclose(self.decay_lr_rate, 0):
  388. self.skip_lr_decay = True
  389. # Configure momentum schedule
  390. def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
  391. if 'betas' not in optimizer.defaults:
  392. optimizer_name = type(optimizer).__name__
  393. logger.warn(
  394. f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
  395. )
  396. self.cycle_momentum = False
  397. return
  398. self.decay_mom_rate = decay_mom_rate
  399. self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
  400. self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
  401. if last_batch_iteration == -1:
  402. for momentum, group in zip(self.min_moms, optimizer.param_groups):
  403. group['betas'] = momentum
  404. if math.isclose(self.decay_mom_rate, 0):
  405. self.skip_mom_decay = True
  406. def _get_scale_factor(self):
  407. batch_iteration = (self.last_batch_iteration + 1)
  408. cycle = math.floor(1 + batch_iteration / self.total_size)
  409. x = 1. + batch_iteration / self.total_size - cycle
  410. if x <= self.step_ratio:
  411. scale_factor = x / self.step_ratio
  412. else:
  413. scale_factor = (x - 1) / (self.step_ratio - 1)
  414. return scale_factor
  415. def _get_cycle_mom(self):
  416. scale_factor = self._get_scale_factor()
  417. momentums = []
  418. for base_betas, max_betas in zip(self.min_moms, self.max_moms):
  419. cycle_min_mom = base_betas[0]
  420. cycle_max_mom = max_betas[0]
  421. base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
  422. momentum = cycle_max_mom - base_height
  423. momentums.append((momentum, base_betas[1]))
  424. return momentums
  425. def _get_cycle_lr(self):
  426. scale_factor = self._get_scale_factor()
  427. lrs = []
  428. for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
  429. base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
  430. lr = cycle_min_lr + base_height
  431. lrs.append(lr)
  432. return lrs
  433. def _get_decay_mom(self, decay_batch_iteration):
  434. if self.skip_mom_decay:
  435. return self.max_moms
  436. decay_interval = decay_batch_iteration / self.decay_step_size
  437. mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
  438. momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
  439. return momentums
  440. def _get_decay_lr(self, decay_batch_iteration):
  441. """Calculates the learning rate at batch index. This function is used
  442. after the cycle completes and post cycle decaying of lr/mom is enabled.
  443. This function treats `self.last_batch_iteration` as the last batch index.
  444. """
  445. if self.skip_lr_decay:
  446. return self.min_lrs
  447. decay_interval = decay_batch_iteration / self.decay_step_size
  448. lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
  449. lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
  450. return lrs
  451. def get_lr(self):
  452. """Calculates the learning rate at batch index. This function treats
  453. `self.last_batch_iteration` as the last batch index.
  454. """
  455. if self.last_batch_iteration < self.total_size:
  456. return self._get_cycle_lr()
  457. return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
  458. def get_mom(self):
  459. """Calculates the momentum at batch index. This function treats
  460. `self.last_batch_iteration` as the last batch index.
  461. """
  462. if not self.cycle_momentum:
  463. return None
  464. if self.last_batch_iteration < self.total_size:
  465. return self._get_cycle_mom()
  466. return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
  467. def get_last_lr(self):
  468. """ Return last computed learning rate by current scheduler.
  469. """
  470. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  471. return self._last_lr
  472. def step(self, batch_iteration=None):
  473. """ Updates the optimizer with the learning rate for the last batch index.
  474. `self.last_batch_iteration` is treated as the last batch index.
  475. If self.cycle_momentum is true, also updates optimizer momentum.
  476. """
  477. if batch_iteration is None:
  478. batch_iteration = self.last_batch_iteration + 1
  479. self.last_batch_iteration = batch_iteration
  480. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  481. param_group['lr'] = lr
  482. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  483. if self.cycle_momentum:
  484. momentums = self.get_mom()
  485. for param_group, momentum in zip(self.optimizer.param_groups, momentums):
  486. param_group['betas'] = momentum
  487. def state_dict(self):
  488. return {'last_batch_iteration': self.last_batch_iteration}
  489. def load_state_dict(self, sd):
  490. self.last_batch_iteration = sd['last_batch_iteration']
  491. class WarmupLR(object):
  492. """Increase the learning rate of each parameter group from min lr to max lr
  493. over warmup_num_steps steps, and then fix at max lr.
  494. Args:
  495. optimizer (Optimizer): Wrapped optimizer.
  496. warmup_min_lr (float or list): minimum learning rate. Default: 0
  497. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  498. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  499. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  500. last_batch_iteration (int): The index of the last batch. Default: -1.
  501. Example:
  502. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  503. >>> scheduler = WarmupLR(optimizer)
  504. >>> data_loader = torch.utils.data.DataLoader(...)
  505. >>> for epoch in range(10):
  506. >>> for batch in data_loader:
  507. >>> train_batch(...)
  508. >>> scheduler.step()
  509. """
  510. def __init__(self,
  511. optimizer: Optimizer,
  512. warmup_min_lr: float = 0.0,
  513. warmup_max_lr: float = 0.001,
  514. warmup_num_steps: int = 1000,
  515. warmup_type: str = WARMUP_LOG_RATE,
  516. last_batch_iteration: int = -1):
  517. self.optimizer = get_torch_optimizer(optimizer)
  518. self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
  519. self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
  520. self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
  521. self.warmup_num_steps = max(2, warmup_num_steps)
  522. # Currently only support linear and log function
  523. if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
  524. logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
  525. f"is set to default (log)")
  526. warmup_type = WARMUP_LOG_RATE
  527. self.warmup_type = warmup_type
  528. self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
  529. self.last_batch_iteration = last_batch_iteration
  530. def get_lr(self):
  531. if self.last_batch_iteration < 0:
  532. logger.warning("Attempting to get learning rate from scheduler before it has started")
  533. return [0.0]
  534. gamma = self._get_gamma()
  535. return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
  536. def get_last_lr(self):
  537. """ Return last computed learning rate by current scheduler.
  538. """
  539. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  540. return self._last_lr
  541. def step(self, last_batch_iteration=None):
  542. if last_batch_iteration is None:
  543. last_batch_iteration = self.last_batch_iteration + 1
  544. self.last_batch_iteration = last_batch_iteration
  545. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  546. param_group['lr'] = lr
  547. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  548. def state_dict(self):
  549. return {'last_batch_iteration': self.last_batch_iteration}
  550. def load_state_dict(self, sd):
  551. self.last_batch_iteration = sd['last_batch_iteration']
  552. def _get_gamma(self):
  553. if self.last_batch_iteration < self.warmup_num_steps:
  554. if self.warmup_type == WARMUP_LOG_RATE:
  555. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  556. elif self.warmup_type == WARMUP_LINEAR_RATE:
  557. return self.last_batch_iteration / self.warmup_num_steps
  558. return 1.0
  559. def _format_param(self, optimizer, param_value, param_name):
  560. if isinstance(param_value, list) or isinstance(param_value, tuple):
  561. if len(param_value) != len(optimizer.param_groups):
  562. raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
  563. FileNotFoundError(param_value)))
  564. return list(param_value)
  565. return [param_value] * len(optimizer.param_groups)
  566. class WarmupDecayLR(WarmupLR):
  567. """Increase the learning rate of each parameter group from min lr to max lr
  568. over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
  569. Args:
  570. optimizer (Optimizer): Wrapped optimizer.
  571. total_num_steps (int): total number of training steps
  572. warmup_min_lr (float or list): minimum learning rate. Default: 0
  573. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  574. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  575. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  576. last_batch_iteration (int): The index of the last batch. Default: -1.
  577. Example:
  578. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  579. >>> scheduler = WarmupDecayLR(optimizer, 1000000)
  580. >>> data_loader = torch.utils.data.DataLoader(...)
  581. >>> for epoch in range(10):
  582. >>> for batch in data_loader:
  583. >>> train_batch(...)
  584. >>> scheduler.step()
  585. """
  586. def __init__(self,
  587. optimizer: Optimizer,
  588. total_num_steps: int,
  589. warmup_min_lr: float = 0.0,
  590. warmup_max_lr: float = 0.001,
  591. warmup_num_steps: int = 1000,
  592. warmup_type: str = WARMUP_LOG_RATE,
  593. last_batch_iteration: int = -1):
  594. self.total_num_steps = total_num_steps
  595. super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
  596. last_batch_iteration)
  597. if self.total_num_steps < self.warmup_num_steps:
  598. logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
  599. total_num_steps, warmup_num_steps))
  600. def _get_gamma(self):
  601. if self.last_batch_iteration < self.warmup_num_steps:
  602. if self.warmup_type == WARMUP_LOG_RATE:
  603. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  604. elif self.warmup_type == WARMUP_LINEAR_RATE:
  605. return self.last_batch_iteration / self.warmup_num_steps
  606. return max(
  607. 0.0,
  608. float(self.total_num_steps - self.last_batch_iteration) /
  609. float(max(1.0, self.total_num_steps - self.warmup_num_steps)))