lr_schedules.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854
  1. """
  2. Copyright 2019 The Microsoft DeepSpeed Team
  3. Implementation of learning rate schedules.
  4. Taken and modified from PyTorch v1.0.1 source
  5. https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
  6. """
  7. import argparse
  8. from torch.optim import Optimizer
  9. import math
  10. from deepspeed.utils import logger
  11. LR_SCHEDULE = 'lr_schedule'
  12. LR_RANGE_TEST = 'LRRangeTest'
  13. ONE_CYCLE = 'OneCycle'
  14. WARMUP_LR = 'WarmupLR'
  15. WARMUP_DECAY_LR = 'WarmupDecayLR'
  16. VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]
  17. LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
  18. LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
  19. LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
  20. LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'
  21. EDGE_VALUE = 'edge_value'
  22. MID_VALUE = 'mid_value'
  23. CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
  24. CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
  25. CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
  26. CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
  27. DECAY_STEP_SIZE = 'decay_step_size'
  28. CYCLE_MIN_LR = 'cycle_min_lr'
  29. CYCLE_MAX_LR = 'cycle_max_lr'
  30. DECAY_LR_RATE = 'decay_lr_rate'
  31. CYCLE_MIN_MOM = 'cycle_min_mom'
  32. CYCLE_MAX_MOM = 'cycle_max_mom'
  33. DECAY_MOM_RATE = 'decay_mom_rate'
  34. WARMUP_MIN_LR = 'warmup_min_lr'
  35. WARMUP_MAX_LR = 'warmup_max_lr'
  36. WARMUP_NUM_STEPS = 'warmup_num_steps'
  37. WARMUP_TYPE = 'warmup_type'
  38. WARMUP_LOG_RATE = 'log'
  39. WARMUP_LINEAR_RATE = 'linear'
  40. TOTAL_NUM_STEPS = 'total_num_steps'
  41. def add_tuning_arguments(parser):
  42. group = parser.add_argument_group('Convergence Tuning',
  43. 'Convergence tuning configurations')
  44. # LR scheduler
  45. group.add_argument('--lr_schedule',
  46. type=str,
  47. default=None,
  48. help='LR schedule for training.')
  49. # Learning rate range test
  50. group.add_argument("--lr_range_test_min_lr",
  51. type=float,
  52. default=0.001,
  53. help='Starting lr value.')
  54. group.add_argument("--lr_range_test_step_rate",
  55. type=float,
  56. default=1.0,
  57. help='scaling rate for LR range test.')
  58. group.add_argument("--lr_range_test_step_size",
  59. type=int,
  60. default=1000,
  61. help='training steps per LR change.')
  62. group.add_argument("--lr_range_test_staircase",
  63. type=bool,
  64. default=False,
  65. help='use staircase scaling for LR range test.')
  66. # OneCycle schedule
  67. group.add_argument("--cycle_first_step_size",
  68. type=int,
  69. default=1000,
  70. help='size of first step of 1Cycle schedule (training steps).')
  71. group.add_argument("--cycle_first_stair_count",
  72. type=int,
  73. default=-1,
  74. help='first stair count for 1Cycle schedule.')
  75. group.add_argument(
  76. "--cycle_second_step_size",
  77. type=int,
  78. default=-1,
  79. help='size of second step of 1Cycle schedule (default first_step_size).')
  80. group.add_argument("--cycle_second_stair_count",
  81. type=int,
  82. default=-1,
  83. help='second stair count for 1Cycle schedule.')
  84. group.add_argument(
  85. "--decay_step_size",
  86. type=int,
  87. default=1000,
  88. help='size of intervals for applying post cycle decay (training steps).')
  89. # 1Cycle LR
  90. group.add_argument("--cycle_min_lr",
  91. type=float,
  92. default=0.01,
  93. help='1Cycle LR lower bound.')
  94. group.add_argument("--cycle_max_lr",
  95. type=float,
  96. default=0.1,
  97. help='1Cycle LR upper bound.')
  98. group.add_argument("--decay_lr_rate",
  99. type=float,
  100. default=0.0,
  101. help='post cycle LR decay rate.')
  102. # 1Cycle Momentum
  103. group.add_argument('--cycle_momentum',
  104. default=False,
  105. action='store_true',
  106. help='Enable 1Cycle momentum schedule.')
  107. group.add_argument("--cycle_min_mom",
  108. type=float,
  109. default=0.8,
  110. help='1Cycle momentum lower bound.')
  111. group.add_argument("--cycle_max_mom",
  112. type=float,
  113. default=0.9,
  114. help='1Cycle momentum upper bound.')
  115. group.add_argument("--decay_mom_rate",
  116. type=float,
  117. default=0.0,
  118. help='post cycle momentum decay rate.')
  119. # Warmup LR
  120. group.add_argument('--warmup_min_lr',
  121. type=float,
  122. default=0,
  123. help='WarmupLR minimum/initial LR value')
  124. group.add_argument('--warmup_max_lr',
  125. type=float,
  126. default=0.001,
  127. help='WarmupLR maximum LR value.')
  128. group.add_argument('--warmup_num_steps',
  129. type=int,
  130. default=1000,
  131. help='WarmupLR step count for LR warmup.')
  132. group.add_argument('--warmup_type',
  133. type=str,
  134. default=WARMUP_LOG_RATE,
  135. help='WarmupLR increasing function during warmup')
  136. return parser
  137. def parse_arguments():
  138. parser = argparse.ArgumentParser()
  139. parser = add_tuning_arguments(parser)
  140. lr_sched_args, unknown_args = parser.parse_known_args()
  141. return lr_sched_args, unknown_args
  142. def override_lr_range_test_params(args, params):
  143. if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
  144. params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
  145. if hasattr(args,
  146. LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
  147. params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
  148. if hasattr(args,
  149. LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
  150. params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
  151. if hasattr(args,
  152. LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
  153. params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
  154. def override_1cycle_params(args, params):
  155. if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
  156. params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
  157. if hasattr(args,
  158. CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
  159. params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
  160. if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
  161. params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
  162. if hasattr(args,
  163. CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
  164. params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
  165. if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
  166. params[DECAY_STEP_SIZE] = args.decay_step_size
  167. # 1Cycle LR params
  168. if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
  169. params[CYCLE_MIN_LR] = args.cycle_min_lr
  170. if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
  171. params[CYCLE_MAX_LR] = args.cycle_max_lr
  172. if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
  173. params[DECAY_LR_RATE] = args.decay_lr_rate
  174. # 1Cycle MOM params
  175. if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
  176. params[CYCLE_MIN_MOM] = args.cycle_min_mom
  177. if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
  178. params[CYCLE_MAX_MOM] = args.cycle_max_mom
  179. if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
  180. params[DECAY_MOM_RATE] = args.decay_mom_rate
  181. def override_warmupLR_params(args, params):
  182. if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
  183. params[WARMUP_MIN_LR] = args.warmup_min_lr
  184. if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
  185. params[WARMUP_MAX_LR] = args.warmup_max_lr
  186. if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
  187. params[WARMUP_NUM_STEPS] = args.warmup_num_steps
  188. if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
  189. params[WARMUP_TYPE] = args.warmup_type
  190. def override_params(args, params):
  191. # LR range test params
  192. override_lr_range_test_params(args, params)
  193. # 1Cycle params
  194. override_1cycle_params(args, params)
  195. # WarmupLR params
  196. override_warmupLR_params(args, params)
  197. def get_config_from_args(args):
  198. if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
  199. return None, '--{} not specified on command line'.format(LR_SCHEDULE)
  200. if not args.lr_schedule in VALID_LR_SCHEDULES:
  201. return None, '{} is not supported LR schedule'.format(args.lr_schedule)
  202. config = {}
  203. config['type'] = args.lr_schedule
  204. config['params'] = {}
  205. if args.lr_schedule == LR_RANGE_TEST:
  206. override_lr_range_test_params(args, config['params'])
  207. elif args.lr_schedule == ONE_CYCLE:
  208. override_1cycle_params(args, config['params'])
  209. else:
  210. override_warmupLR_params(args, config['params'])
  211. return config, None
  212. def get_lr_from_config(config):
  213. if not 'type' in config:
  214. return None, 'LR schedule type not defined in config'
  215. if not 'params' in config:
  216. return None, 'LR schedule params not defined in config'
  217. lr_schedule = config['type']
  218. lr_params = config['params']
  219. if not lr_schedule in VALID_LR_SCHEDULES:
  220. return None, '{} is not a valid LR schedule'.format(lr_schedule)
  221. if lr_schedule == LR_RANGE_TEST:
  222. return lr_params[LR_RANGE_TEST_MIN_LR], ''
  223. if lr_schedule == ONE_CYCLE:
  224. return lr_params[CYCLE_MAX_LR], ''
  225. # Warmup LR
  226. return lr_params[WARMUP_MAX_LR], ''
  227. """
  228. Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
  229. optimizer to see if requirement is satisfied.
  230. TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
  231. """
  232. def get_torch_optimizer(optimizer):
  233. if isinstance(optimizer, Optimizer):
  234. return optimizer
  235. if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
  236. return optimizer.optimizer
  237. raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(
  238. type(optimizer).__name__))
class LRRangeTest(object):
    """Sets the learning rate of each parameter group according to
    learning rate range test (LRRT) policy. The policy increases learning
    rate starting from a base value with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters: Part1`_.

    LRRT policy is used for finding maximum LR that trains a model without
    divergence, and can be used to configure the LR boundaries for Cyclic LR
    schedules.

    LRRT changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_range_test_min_lr (float or list): Initial learning rate which is the
            lower boundary in the range test for each parameter group.
        lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
        lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
        lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = LRRangeTest(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
        https://arxiv.org/abs/1803.09820
    """
    def __init__(self,
                 optimizer: Optimizer,
                 lr_range_test_min_lr: float = 1e-3,
                 lr_range_test_step_size: int = 2000,
                 lr_range_test_step_rate: float = 1.0,
                 lr_range_test_staircase: bool = False,
                 last_batch_iteration: int = -1):
        self.optimizer = get_torch_optimizer(optimizer)

        # Per-group starting LRs: a list/tuple must match the number of param
        # groups; a scalar is broadcast to every group.
        if isinstance(lr_range_test_min_lr,
                      list) or isinstance(lr_range_test_min_lr,
                                          tuple):
            if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
                raise ValueError("expected {} lr_range_test_min_lr, got {}".format(
                    len(self.optimizer.param_groups),
                    len(lr_range_test_min_lr)))
            self.min_lr = list(lr_range_test_min_lr)
        else:
            self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)

        self.step_size = lr_range_test_step_size
        self.step_rate = lr_range_test_step_rate
        self.last_batch_iteration = last_batch_iteration
        self.staircase = lr_range_test_staircase
        # Staircase increases LR in discrete jumps once per step_size interval;
        # continuous increases it a little every batch.
        self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval

        # Fresh run (not resuming): seed the optimizer with the minimum LRs.
        if last_batch_iteration == -1:
            self._update_optimizer(self.min_lr)

    def _staircase_interval(self):
        # Whole number of completed step_size intervals.
        return math.floor(float(self.last_batch_iteration + 1) / self.step_size)

    def _continuous_interval(self):
        # Fractional progress through step_size intervals.
        return float(self.last_batch_iteration + 1) / self.step_size

    def _get_increase(self):
        # LR multiplier grows linearly with the elapsed interval count.
        return (1 + self.step_rate * self.interval_fn())

    def get_lr(self):
        """Compute the LR for each param group at the current iteration."""
        lr_increase = self._get_increase()
        return [
            lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr
        ]

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
        return self._last_lr

    def _update_optimizer(self, group_lrs):
        # Write the given LRs into the optimizer's param groups, in order.
        for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
            param_group['lr'] = lr

    def step(self, batch_iteration=None):
        """Advance one batch (or jump to ``batch_iteration``) and update LRs."""
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        self._update_optimizer(self.get_lr())
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        # The iteration counter is the only state needed to resume.
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']
  326. class OneCycle(object):
  327. """Sets the learning rate of each parameter group according to
  328. 1Cycle learning rate policy (1CLR). 1CLR is a variation of the
  329. Cyclical Learning Rate (CLR) policy that involves one cycle followed by
  330. decay. The policy simultaneously cycles the learning rate (and momentum)
  331. between two boundaries with a constant frequency, as detailed in
  332. the paper `A disciplined approach to neural network hyper-parameters`_.
  333. 1CLR policy changes the learning rate after every batch.
  334. `step` should be called after a batch has been used for training.
  335. This implementation was adapted from the github repo: `pytorch/pytorch`_
  336. Args:
  337. optimizer (Optimizer): Wrapped optimizer.
  338. cycle_min_lr (float or list): Initial learning rate which is the
  339. lower boundary in the cycle for each parameter group.
  340. cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
  341. for each parameter group. Functionally,
  342. it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
  343. The lr at any cycle is the sum of cycle_min_lr
  344. and some scaling of the amplitude; therefore
  345. cycle_max_lr may not actually be reached depending on
  346. scaling function.
  347. decay_lr_rate(float): Decay rate for learning rate. Default: 0.
  348. cycle_first_step_size (int): Number of training iterations in the
  349. increasing half of a cycle. Default: 2000
  350. cycle_second_step_size (int): Number of training iterations in the
  351. decreasing half of a cycle. If cycle_second_step_size is None,
  352. it is set to cycle_first_step_size. Default: None
  353. cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
  354. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  355. cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
  356. lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
  357. decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
  358. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  359. to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
  360. Default: True
  361. cycle_min_mom (float or list): Initial momentum which is the
  362. lower boundary in the cycle for each parameter group.
  363. Default: 0.8
  364. cycle_max_mom (float or list): Upper momentum boundaries in the cycle
  365. for each parameter group. Functionally,
  366. it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
  367. The momentum at any cycle is the difference of cycle_max_mom
  368. and some scaling of the amplitude; therefore
  369. cycle_min_mom may not actually be reached depending on
  370. scaling function. Default: 0.9
  371. decay_mom_rate (float): Decay rate for momentum. Default: 0.
  372. last_batch_iteration (int): The index of the last batch. This parameter is used when
  373. resuming a training job. Since `step()` should be invoked after each
  374. batch instead of after each epoch, this number represents the total
  375. number of *batches* computed, not the total number of epochs computed.
  376. When last_batch_iteration=-1, the schedule is started from the beginning.
  377. Default: -1
  378. Example:
  379. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  380. >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
  381. >>> data_loader = torch.utils.data.DataLoader(...)
  382. >>> for epoch in range(10):
  383. >>> for batch in data_loader:
  384. >>> train_batch(...)
  385. >>> scheduler.step()
  386. .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
  387. """
  388. def __init__(self,
  389. optimizer,
  390. cycle_min_lr,
  391. cycle_max_lr,
  392. decay_lr_rate=0.,
  393. cycle_first_step_size=2000,
  394. cycle_second_step_size=None,
  395. cycle_first_stair_count=0,
  396. cycle_second_stair_count=None,
  397. decay_step_size=0,
  398. cycle_momentum=True,
  399. cycle_min_mom=0.8,
  400. cycle_max_mom=0.9,
  401. decay_mom_rate=0.,
  402. last_batch_iteration=-1):
  403. self.optimizer = get_torch_optimizer(optimizer)
  404. # Initialize cycle shape
  405. self._initialize_cycle(cycle_first_step_size,
  406. cycle_second_step_size,
  407. cycle_first_stair_count,
  408. cycle_second_stair_count,
  409. decay_step_size)
  410. # Initialize cycle lr
  411. self._initialize_lr(self.optimizer,
  412. cycle_min_lr,
  413. cycle_max_lr,
  414. decay_lr_rate,
  415. last_batch_iteration)
  416. # Initialize cyclic momentum
  417. self.cycle_momentum = cycle_momentum
  418. if cycle_momentum:
  419. self._initialize_momentum(self.optimizer,
  420. cycle_min_mom,
  421. cycle_max_mom,
  422. decay_mom_rate,
  423. last_batch_iteration)
  424. # Initialize batch iteration tracker
  425. self.last_batch_iteration = last_batch_iteration
  426. # Configure cycle shape
  427. def _initialize_cycle(self,
  428. cycle_first_step_size,
  429. cycle_second_step_size,
  430. cycle_first_stair_count,
  431. cycle_second_stair_count,
  432. decay_step_size):
  433. cycle_first_step_size = float(cycle_first_step_size)
  434. cycle_second_step_size = float(
  435. cycle_second_step_size
  436. ) if cycle_second_step_size is not None else cycle_first_step_size
  437. self.total_size = cycle_first_step_size + cycle_second_step_size
  438. self.step_ratio = cycle_first_step_size / self.total_size
  439. self.first_stair_count = cycle_first_stair_count
  440. self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
  441. self.decay_step_size = decay_step_size
  442. if math.isclose(self.decay_step_size, 0):
  443. self.skip_lr_decay = True
  444. self.skip_mom_decay = True
  445. else:
  446. self.skip_lr_decay = False
  447. self.skip_mom_decay = False
  448. # Configure lr schedule
  449. def _initialize_lr(self,
  450. optimizer,
  451. cycle_min_lr,
  452. cycle_max_lr,
  453. decay_lr_rate,
  454. last_batch_iteration):
  455. self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
  456. if last_batch_iteration == -1:
  457. for lr, group in zip(self.min_lrs, optimizer.param_groups):
  458. group['lr'] = lr
  459. self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
  460. self.decay_lr_rate = decay_lr_rate
  461. if math.isclose(self.decay_lr_rate, 0):
  462. self.skip_lr_decay = True
  463. # Configure momentum schedule
  464. def _initialize_momentum(self,
  465. optimizer,
  466. cycle_min_mom,
  467. cycle_max_mom,
  468. decay_mom_rate,
  469. last_batch_iteration):
  470. if 'betas' not in optimizer.defaults:
  471. optimizer_name = type(optimizer).__name__
  472. logger.warn(
  473. f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
  474. )
  475. self.cycle_momentum = False
  476. return
  477. self.decay_mom_rate = decay_mom_rate
  478. self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
  479. self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
  480. if last_batch_iteration == -1:
  481. for momentum, group in zip(self.min_moms, optimizer.param_groups):
  482. group['betas'] = momentum
  483. if math.isclose(self.decay_mom_rate, 0):
  484. self.skip_mom_decay = True
  485. def _get_scale_factor(self):
  486. batch_iteration = (self.last_batch_iteration + 1)
  487. cycle = math.floor(1 + batch_iteration / self.total_size)
  488. x = 1. + batch_iteration / self.total_size - cycle
  489. if x <= self.step_ratio:
  490. scale_factor = x / self.step_ratio
  491. else:
  492. scale_factor = (x - 1) / (self.step_ratio - 1)
  493. return scale_factor
  494. def _get_cycle_mom(self):
  495. scale_factor = self._get_scale_factor()
  496. momentums = []
  497. for base_betas, max_betas in zip(self.min_moms, self.max_moms):
  498. cycle_min_mom = base_betas[0]
  499. cycle_max_mom = max_betas[0]
  500. base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
  501. momentum = cycle_max_mom - base_height
  502. momentums.append((momentum, base_betas[1]))
  503. return momentums
  504. def _get_cycle_lr(self):
  505. scale_factor = self._get_scale_factor()
  506. lrs = []
  507. for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
  508. base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
  509. lr = cycle_min_lr + base_height
  510. lrs.append(lr)
  511. return lrs
  512. def _get_decay_mom(self, decay_batch_iteration):
  513. if self.skip_mom_decay:
  514. return self.max_moms
  515. decay_interval = decay_batch_iteration / self.decay_step_size
  516. mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
  517. momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
  518. return momentums
  519. def _get_decay_lr(self, decay_batch_iteration):
  520. """Calculates the learning rate at batch index. This function is used
  521. after the cycle completes and post cycle decaying of lr/mom is enabled.
  522. This function treats `self.last_batch_iteration` as the last batch index.
  523. """
  524. if self.skip_lr_decay:
  525. return self.min_lrs
  526. decay_interval = decay_batch_iteration / self.decay_step_size
  527. lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
  528. lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
  529. return lrs
  530. def get_lr(self):
  531. """Calculates the learning rate at batch index. This function treats
  532. `self.last_batch_iteration` as the last batch index.
  533. """
  534. if self.last_batch_iteration < self.total_size:
  535. return self._get_cycle_lr()
  536. return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
  537. def get_mom(self):
  538. """Calculates the momentum at batch index. This function treats
  539. `self.last_batch_iteration` as the last batch index.
  540. """
  541. if not self.cycle_momentum:
  542. return None
  543. if self.last_batch_iteration < self.total_size:
  544. return self._get_cycle_mom()
  545. return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
  546. def get_last_lr(self):
  547. """ Return last computed learning rate by current scheduler.
  548. """
  549. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  550. return self._last_lr
  551. def step(self, batch_iteration=None):
  552. """ Updates the optimizer with the learning rate for the last batch index.
  553. `self.last_batch_iteration` is treated as the last batch index.
  554. If self.cycle_momentum is true, also updates optimizer momentum.
  555. """
  556. if batch_iteration is None:
  557. batch_iteration = self.last_batch_iteration + 1
  558. self.last_batch_iteration = batch_iteration
  559. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  560. param_group['lr'] = lr
  561. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  562. if self.cycle_momentum:
  563. momentums = self.get_mom()
  564. for param_group, momentum in zip(self.optimizer.param_groups, momentums):
  565. param_group['betas'] = momentum
  566. def state_dict(self):
  567. return {'last_batch_iteration': self.last_batch_iteration}
  568. def load_state_dict(self, sd):
  569. self.last_batch_iteration = sd['last_batch_iteration']
  570. class WarmupLR(object):
  571. """Increase the learning rate of each parameter group from min lr to max lr
  572. over warmup_num_steps steps, and then fix at max lr.
  573. Args:
  574. optimizer (Optimizer): Wrapped optimizer.
  575. warmup_min_lr (float or list): minimum learning rate. Default: 0
  576. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  577. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  578. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  579. last_batch_iteration (int): The index of the last batch. Default: -1.
  580. Example:
  581. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  582. >>> scheduler = WarmupLR(optimizer)
  583. >>> data_loader = torch.utils.data.DataLoader(...)
  584. >>> for epoch in range(10):
  585. >>> for batch in data_loader:
  586. >>> train_batch(...)
  587. >>> scheduler.step()
  588. """
  589. def __init__(self,
  590. optimizer: Optimizer,
  591. warmup_min_lr: float = 0.0,
  592. warmup_max_lr: float = 0.001,
  593. warmup_num_steps: int = 1000,
  594. warmup_type: str = WARMUP_LOG_RATE,
  595. last_batch_iteration: int = -1):
  596. self.optimizer = get_torch_optimizer(optimizer)
  597. self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
  598. self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
  599. self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
  600. self.warmup_num_steps = max(2, warmup_num_steps)
  601. # Currently only support linear and log function
  602. if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
  603. logger.warning(
  604. f"Using unknown warmup_type: {warmup_type}. The increasing function "
  605. f"is set to default (log)")
  606. warmup_type = WARMUP_LOG_RATE
  607. self.warmup_type = warmup_type
  608. self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
  609. self.last_batch_iteration = last_batch_iteration
  610. def get_lr(self):
  611. if self.last_batch_iteration < 0:
  612. logger.warning(
  613. "Attempting to get learning rate from scheduler before it has started")
  614. return [0.0]
  615. gamma = self._get_gamma()
  616. return [
  617. min_lr + (delta_lr * gamma) for min_lr,
  618. delta_lr in zip(self.min_lrs,
  619. self.delta_lrs)
  620. ]
  621. def get_last_lr(self):
  622. """ Return last computed learning rate by current scheduler.
  623. """
  624. assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
  625. return self._last_lr
  626. def step(self, last_batch_iteration=None):
  627. if last_batch_iteration is None:
  628. last_batch_iteration = self.last_batch_iteration + 1
  629. self.last_batch_iteration = last_batch_iteration
  630. for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
  631. param_group['lr'] = lr
  632. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  633. def state_dict(self):
  634. return {'last_batch_iteration': self.last_batch_iteration}
  635. def load_state_dict(self, sd):
  636. self.last_batch_iteration = sd['last_batch_iteration']
  637. def _get_gamma(self):
  638. if self.last_batch_iteration < self.warmup_num_steps:
  639. if self.warmup_type == WARMUP_LOG_RATE:
  640. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  641. elif self.warmup_type == WARMUP_LINEAR_RATE:
  642. return self.last_batch_iteration / self.warmup_num_steps
  643. return 1.0
  644. def _format_param(self, optimizer, param_value, param_name):
  645. if isinstance(param_value, list) or isinstance(param_value, tuple):
  646. if len(param_value) != len(optimizer.param_groups):
  647. raise ValueError("expected {} value for {}, got {}".format(
  648. len(optimizer.param_groups),
  649. param_name,
  650. FileNotFoundError(param_value)))
  651. return list(param_value)
  652. return [param_value] * len(optimizer.param_groups)
  653. class WarmupDecayLR(WarmupLR):
  654. """Increase the learning rate of each parameter group from min lr to max lr
  655. over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
  656. Args:
  657. optimizer (Optimizer): Wrapped optimizer.
  658. total_num_steps (int): total number of training steps
  659. warmup_min_lr (float or list): minimum learning rate. Default: 0
  660. warmup_max_lr (float or list): maximum learning rate. Default: 0.001
  661. warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
  662. warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
  663. last_batch_iteration (int): The index of the last batch. Default: -1.
  664. Example:
  665. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  666. >>> scheduler = WarmupDecayLR(optimizer, 1000000)
  667. >>> data_loader = torch.utils.data.DataLoader(...)
  668. >>> for epoch in range(10):
  669. >>> for batch in data_loader:
  670. >>> train_batch(...)
  671. >>> scheduler.step()
  672. """
  673. def __init__(self,
  674. optimizer: Optimizer,
  675. total_num_steps: int,
  676. warmup_min_lr: float = 0.0,
  677. warmup_max_lr: float = 0.001,
  678. warmup_num_steps: int = 1000,
  679. warmup_type: str = WARMUP_LOG_RATE,
  680. last_batch_iteration: int = -1):
  681. self.total_num_steps = total_num_steps
  682. super(WarmupDecayLR,
  683. self).__init__(optimizer,
  684. warmup_min_lr,
  685. warmup_max_lr,
  686. warmup_num_steps,
  687. warmup_type,
  688. last_batch_iteration)
  689. if self.total_num_steps < self.warmup_num_steps:
  690. logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
  691. total_num_steps,
  692. warmup_num_steps))
  693. def _get_gamma(self):
  694. if self.last_batch_iteration < self.warmup_num_steps:
  695. if self.warmup_type == WARMUP_LOG_RATE:
  696. return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
  697. elif self.warmup_type == WARMUP_LINEAR_RATE:
  698. return self.last_batch_iteration / self.warmup_num_steps
  699. return max(
  700. 0.0,
  701. float(self.total_num_steps - self.last_batch_iteration) /
  702. float(max(1.0,
  703. self.total_num_steps - self.warmup_num_steps)))