- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- """
- Implementation of learning rate schedules.
- Taken and modified from PyTorch v1.0.1 source
- https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
- """
- import argparse
- from torch.optim import Optimizer
- import math
- from deepspeed.utils import logger
- LR_SCHEDULE = 'lr_schedule'
- LR_RANGE_TEST = 'LRRangeTest'
- ONE_CYCLE = 'OneCycle'
- WARMUP_LR = 'WarmupLR'
- WARMUP_DECAY_LR = 'WarmupDecayLR'
- WARMUP_COSINE_LR = 'WarmupCosineLR'
- VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR, WARMUP_COSINE_LR]
- LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
- LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
- LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
- LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'
- EDGE_VALUE = 'edge_value'
- MID_VALUE = 'mid_value'
- CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
- CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
- CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
- CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
- DECAY_STEP_SIZE = 'decay_step_size'
- CYCLE_MIN_LR = 'cycle_min_lr'
- CYCLE_MAX_LR = 'cycle_max_lr'
- DECAY_LR_RATE = 'decay_lr_rate'
- CYCLE_MIN_MOM = 'cycle_min_mom'
- CYCLE_MAX_MOM = 'cycle_max_mom'
- DECAY_MOM_RATE = 'decay_mom_rate'
- WARMUP_MIN_LR = 'warmup_min_lr'
- WARMUP_MAX_LR = 'warmup_max_lr'
- WARMUP_NUM_STEPS = 'warmup_num_steps'
- WARMUP_TYPE = 'warmup_type'
- WARMUP_LOG_RATE = 'log'
- WARMUP_LINEAR_RATE = 'linear'
- WARMUP_MIN_RATIO = 'warmup_min_ratio'
- COS_MIN_RATIO = 'cos_min_ratio'
- TOTAL_NUM_STEPS = 'total_num_steps'
- def add_tuning_arguments(parser):
- group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')
- # LR scheduler
- group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')
- # Learning rate range test
- group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.')
- group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.')
- group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.')
- group.add_argument("--lr_range_test_staircase",
- type=bool,
- default=False,
- help='use staircase scaling for LR range test.')
- # OneCycle schedule
- group.add_argument("--cycle_first_step_size",
- type=int,
- default=1000,
- help='size of first step of 1Cycle schedule (training steps).')
- group.add_argument("--cycle_first_stair_count",
- type=int,
- default=-1,
- help='first stair count for 1Cycle schedule.')
- group.add_argument("--cycle_second_step_size",
- type=int,
- default=-1,
- help='size of second step of 1Cycle schedule (default first_step_size).')
- group.add_argument("--cycle_second_stair_count",
- type=int,
- default=-1,
- help='second stair count for 1Cycle schedule.')
- group.add_argument("--decay_step_size",
- type=int,
- default=1000,
- help='size of intervals for applying post cycle decay (training steps).')
- # 1Cycle LR
- group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.')
- group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.')
- group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.')
- # 1Cycle Momentum
- group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
- group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.')
- group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.')
- group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.')
- # Warmup LR
- group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
- group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
- group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
- group.add_argument('--warmup_type',
- type=str,
- default=WARMUP_LOG_RATE,
- help='WarmupLR increasing function during warmup')
- # WarmupCosineLR
- group.add_argument("--warmup_min_ratio", type=float, default=0.01, help='WarmupCosineLR warmup start LR ratio.')
- group.add_argument("--cos_min_ratio", type=float, default=0.01, help='WarmupCosineLR cosine end LR ratio.')
- return parser
- def parse_arguments():
- parser = argparse.ArgumentParser()
- parser = add_tuning_arguments(parser)
- lr_sched_args, unknown_args = parser.parse_known_args()
- return lr_sched_args, unknown_args
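- # A minimal usage sketch (not part of the original module; flag values are illustrative):
- # parse_arguments() consumes only the tuning flags defined above and leaves everything
- # else untouched, so it can sit alongside an application's own argparse setup.
- #
- #   lr_sched_args, remaining = parse_arguments()
- #   # e.g. invoked as: python train.py --lr_schedule WarmupLR --warmup_num_steps 500 --my_app_flag 1
- #   # lr_sched_args.warmup_num_steps == 500, and '--my_app_flag 1' stays in `remaining`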
- def override_lr_range_test_params(args, params):
- if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
- params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
- if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
- params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
- if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
- params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
- if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
- params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
- def override_1cycle_params(args, params):
- if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
- params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
- if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
- params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
- if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
- params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
- if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
- params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
- if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
- params[DECAY_STEP_SIZE] = args.decay_step_size
- # 1Cycle LR params
- if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
- params[CYCLE_MIN_LR] = args.cycle_min_lr
- if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
- params[CYCLE_MAX_LR] = args.cycle_max_lr
- if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
- params[DECAY_LR_RATE] = args.decay_lr_rate
- # 1Cycle MOM params
- if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
- params[CYCLE_MIN_MOM] = args.cycle_min_mom
- if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
- params[CYCLE_MAX_MOM] = args.cycle_max_mom
- if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
- params[DECAY_MOM_RATE] = args.decay_mom_rate
- def override_warmupLR_params(args, params):
- if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
- params[WARMUP_MIN_LR] = args.warmup_min_lr
- if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
- params[WARMUP_MAX_LR] = args.warmup_max_lr
- if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
- params[WARMUP_NUM_STEPS] = args.warmup_num_steps
- if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
- params[WARMUP_TYPE] = args.warmup_type
- def override_params(args, params):
- # LR range test params
- override_lr_range_test_params(args, params)
- # 1Cycle params
- override_1cycle_params(args, params)
- # WarmupLR params
- override_warmupLR_params(args, params)
- def get_config_from_args(args):
- if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
- return None, '--{} not specified on command line'.format(LR_SCHEDULE)
- if args.lr_schedule not in VALID_LR_SCHEDULES:
- return None, '{} is not a supported LR schedule'.format(args.lr_schedule)
- config = {}
- config['type'] = args.lr_schedule
- config['params'] = {}
- if args.lr_schedule == LR_RANGE_TEST:
- override_lr_range_test_params(args, config['params'])
- elif args.lr_schedule == ONE_CYCLE:
- override_1cycle_params(args, config['params'])
- else:
- override_warmupLR_params(args, config['params'])
- return config, None
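- # Illustrative sketch (assumed command-line values, not from the original source): with
- # --lr_schedule WarmupLR on the command line and the remaining tuning flags left at their
- # defaults, get_config_from_args() builds a scheduler config dict in the same
- # {'type': ..., 'params': {...}} shape used by DeepSpeed config files.
- #
- #   args, _ = parse_arguments()
- #   config, err = get_config_from_args(args)
- #   if err is None:
- #       # e.g. {'type': 'WarmupLR',
- #       #       'params': {'warmup_min_lr': 0, 'warmup_max_lr': 0.001,
- #       #                  'warmup_num_steps': 1000, 'warmup_type': 'log'}}
- #       print(config)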
- def get_lr_from_config(config):
- if 'type' not in config:
- return None, 'LR schedule type not defined in config'
- if 'params' not in config:
- return None, 'LR schedule params not defined in config'
- lr_schedule = config['type']
- lr_params = config['params']
- if lr_schedule not in VALID_LR_SCHEDULES:
- return None, '{} is not a valid LR schedule'.format(lr_schedule)
- if lr_schedule == LR_RANGE_TEST:
- return lr_params[LR_RANGE_TEST_MIN_LR], ''
- if lr_schedule == ONE_CYCLE:
- return lr_params[CYCLE_MAX_LR], ''
- # Warmup LR
- return lr_params[WARMUP_MAX_LR], ''
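- # Hedged example of reading the initial LR back out of a config dict; the dict literal
- # below is hypothetical and mirrors the WarmupLR branch above.
- #
- #   lr, err = get_lr_from_config({'type': 'WarmupLR',
- #                                 'params': {'warmup_max_lr': 0.001}})
- #   # lr == 0.001, err == ''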
- """
- Only optimizers that are a subclass of torch.optim.Optimizer are supported, so check both the passed optimizer and the wrapped
- optimizer to see if the requirement is satisfied.
- TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
- """
- def get_torch_optimizer(optimizer):
- if isinstance(optimizer, Optimizer):
- return optimizer
- if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
- return optimizer.optimizer
- raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))
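- # Sketch of the unwrapping behavior described above (the wrapper class and `model` below
- # are hypothetical; the wrapper is assumed only to expose an `.optimizer` attribute):
- #
- #   inner = torch.optim.SGD(model.parameters(), lr=0.1)
- #   class Wrapper:
- #       def __init__(self, opt):
- #           self.optimizer = opt
- #   get_torch_optimizer(inner) is inner          # already a torch Optimizer
- #   get_torch_optimizer(Wrapper(inner)) is inner # unwrapped via .optimizer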
- class LRRangeTest(object):
- """Sets the learning rate of each parameter group according to
- learning rate range test (LRRT) policy. The policy increases learning
- rate starting from a base value with a constant frequency, as detailed in
- the paper `A disciplined approach to neural network hyper-parameters: Part1`_.
- The LRRT policy is used for finding the maximum LR that trains a model without divergence, and can be used to
- configure the LR boundaries for cyclic LR schedules.
- LRRT changes the learning rate after every batch.
- `step` should be called after a batch has been used for training.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- lr_range_test_min_lr (float or list): Initial learning rate which is the
- lower boundary in the range test for each parameter group.
- lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
- lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
- lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
- last_batch_iteration (int): The index of the last batch. This parameter is used when
- resuming a training job. Since `step()` should be invoked after each
- batch instead of after each epoch, this number represents the total
- number of *batches* computed, not the total number of epochs computed.
- When last_batch_iteration=-1, the schedule is started from the beginning.
- Default: -1
- Example:
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = LRRangeTest(optimizer)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
- https://arxiv.org/abs/1803.09820
- """
- def __init__(self,
- optimizer: Optimizer,
- lr_range_test_min_lr: float = 1e-3,
- lr_range_test_step_size: int = 2000,
- lr_range_test_step_rate: float = 1.0,
- lr_range_test_staircase: bool = False,
- last_batch_iteration: int = -1):
- self.optimizer = get_torch_optimizer(optimizer)
- if isinstance(lr_range_test_min_lr, (list, tuple)):
- if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
- raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups),
- len(lr_range_test_min_lr)))
- self.min_lr = list(lr_range_test_min_lr)
- else:
- self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
- self.step_size = lr_range_test_step_size
- self.step_rate = lr_range_test_step_rate
- self.last_batch_iteration = last_batch_iteration
- self.staircase = lr_range_test_staircase
- self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval
- if last_batch_iteration == -1:
- self._update_optimizer(self.min_lr)
- def _staircase_interval(self):
- return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
- def _continuous_interval(self):
- return float(self.last_batch_iteration + 1) / self.step_size
- def _get_increase(self):
- return (1 + self.step_rate * self.interval_fn())
- def get_lr(self):
- lr_increase = self._get_increase()
- return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
- return self._last_lr
- def _update_optimizer(self, group_lrs):
- for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
- param_group['lr'] = lr
- def step(self, batch_iteration=None):
- if batch_iteration is None:
- batch_iteration = self.last_batch_iteration + 1
- self.last_batch_iteration = batch_iteration
- self._update_optimizer(self.get_lr())
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
- def state_dict(self):
- return {'last_batch_iteration': self.last_batch_iteration}
- def load_state_dict(self, sd):
- self.last_batch_iteration = sd['last_batch_iteration']
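- # Worked sketch of the LRRT growth rule (illustrative numbers, `optimizer` assumed to exist):
- # with lr_range_test_min_lr=1e-3, lr_range_test_step_size=100 and lr_range_test_step_rate=1.0,
- # the continuous schedule gives lr = 1e-3 * (1 + 1.0 * (iteration + 1) / 100), i.e. ~1e-3 at
- # step 0, ~1.5e-3 at step 49 and ~2e-3 at step 99; with lr_range_test_staircase=True the same
- # values are held constant within each 100-step interval.
- #
- #   scheduler = LRRangeTest(optimizer, lr_range_test_min_lr=1e-3,
- #                           lr_range_test_step_size=100, lr_range_test_step_rate=1.0)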
- class OneCycle(object):
- """Sets the learning rate of each parameter group according to
- 1Cycle learning rate policy (1CLR). 1CLR is a variation of the
- Cyclical Learning Rate (CLR) policy that involves one cycle followed by
- decay. The policy simultaneously cycles the learning rate (and momentum)
- between two boundaries with a constant frequency, as detailed in
- the paper `A disciplined approach to neural network hyper-parameters`_.
- 1CLR policy changes the learning rate after every batch.
- `step` should be called after a batch has been used for training.
- This implementation was adapted from the github repo: `pytorch/pytorch`_
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- cycle_min_lr (float or list): Initial learning rate which is the
- lower boundary in the cycle for each parameter group.
- cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
- The lr at any cycle is the sum of cycle_min_lr
- and some scaling of the amplitude; therefore
- cycle_max_lr may not actually be reached depending on
- scaling function.
- decay_lr_rate(float): Decay rate for learning rate. Default: 0.
- cycle_first_step_size (int): Number of training iterations in the
- increasing half of a cycle. Default: 2000
- cycle_second_step_size (int): Number of training iterations in the
- decreasing half of a cycle. If cycle_second_step_size is None,
- it is set to cycle_first_step_size. Default: None
- cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
- lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
- cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
- lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
- decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
- cycle_momentum (bool): If ``True``, momentum is cycled inversely
- to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
- Default: True
- cycle_min_mom (float or list): Initial momentum which is the
- lower boundary in the cycle for each parameter group.
- Default: 0.8
- cycle_max_mom (float or list): Upper momentum boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
- The momentum at any cycle is the difference of cycle_max_mom
- and some scaling of the amplitude; therefore
- cycle_min_mom may not actually be reached depending on
- scaling function. Default: 0.9
- decay_mom_rate (float): Decay rate for momentum. Default: 0.
- last_batch_iteration (int): The index of the last batch. This parameter is used when
- resuming a training job. Since `step()` should be invoked after each
- batch instead of after each epoch, this number represents the total
- number of *batches* computed, not the total number of epochs computed.
- When last_batch_iteration=-1, the schedule is started from the beginning.
- Default: -1
- Example:
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
- """
- def __init__(self,
- optimizer,
- cycle_min_lr,
- cycle_max_lr,
- decay_lr_rate=0.,
- cycle_first_step_size=2000,
- cycle_second_step_size=None,
- cycle_first_stair_count=0,
- cycle_second_stair_count=None,
- decay_step_size=0,
- cycle_momentum=True,
- cycle_min_mom=0.8,
- cycle_max_mom=0.9,
- decay_mom_rate=0.,
- last_batch_iteration=-1):
- self.optimizer = get_torch_optimizer(optimizer)
- # Initialize cycle shape
- self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
- cycle_second_stair_count, decay_step_size)
- # Initialize cycle lr
- self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)
- # Initialize cyclic momentum
- self.cycle_momentum = cycle_momentum
- if cycle_momentum:
- self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
- last_batch_iteration)
- # Initialize batch iteration tracker
- self.last_batch_iteration = last_batch_iteration
- # Configure cycle shape
- def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
- cycle_second_stair_count, decay_step_size):
- cycle_first_step_size = float(cycle_first_step_size)
- cycle_second_step_size = float(
- cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size
- self.total_size = cycle_first_step_size + cycle_second_step_size
- self.step_ratio = cycle_first_step_size / self.total_size
- self.first_stair_count = cycle_first_stair_count
- self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
- self.decay_step_size = decay_step_size
- if math.isclose(self.decay_step_size, 0):
- self.skip_lr_decay = True
- self.skip_mom_decay = True
- else:
- self.skip_lr_decay = False
- self.skip_mom_decay = False
- # Configure lr schedule
- def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
- self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
- if last_batch_iteration == -1:
- for lr, group in zip(self.min_lrs, optimizer.param_groups):
- group['lr'] = lr
- self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
- self.decay_lr_rate = decay_lr_rate
- if math.isclose(self.decay_lr_rate, 0):
- self.skip_lr_decay = True
- # Configure momentum schedule
- def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
- if 'betas' not in optimizer.defaults:
- optimizer_name = type(optimizer).__name__
- logger.warning(
- f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum (no 'betas' attribute in defaults)"
- )
- self.cycle_momentum = False
- return
- self.decay_mom_rate = decay_mom_rate
- self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
- self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
- if last_batch_iteration == -1:
- for momentum, group in zip(self.min_moms, optimizer.param_groups):
- group['betas'] = momentum
- if math.isclose(self.decay_mom_rate, 0):
- self.skip_mom_decay = True
- def _get_scale_factor(self):
- batch_iteration = (self.last_batch_iteration + 1)
- cycle = math.floor(1 + batch_iteration / self.total_size)
- x = 1. + batch_iteration / self.total_size - cycle
- if x <= self.step_ratio:
- scale_factor = x / self.step_ratio
- else:
- scale_factor = (x - 1) / (self.step_ratio - 1)
- return scale_factor
- def _get_cycle_mom(self):
- scale_factor = self._get_scale_factor()
- momentums = []
- for base_betas, max_betas in zip(self.min_moms, self.max_moms):
- cycle_min_mom = base_betas[0]
- cycle_max_mom = max_betas[0]
- base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
- momentum = cycle_max_mom - base_height
- momentums.append((momentum, base_betas[1]))
- return momentums
- def _get_cycle_lr(self):
- scale_factor = self._get_scale_factor()
- lrs = []
- for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
- base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
- lr = cycle_min_lr + base_height
- lrs.append(lr)
- return lrs
- def _get_decay_mom(self, decay_batch_iteration):
- if self.skip_mom_decay:
- return self.max_moms
- decay_interval = decay_batch_iteration / self.decay_step_size
- mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
- momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
- return momentums
- def _get_decay_lr(self, decay_batch_iteration):
- """Calculates the learning rate at batch index. This function is used
- after the cycle completes and post cycle decaying of lr/mom is enabled.
- This function treats `self.last_batch_iteration` as the last batch index.
- """
- if self.skip_lr_decay:
- return self.min_lrs
- decay_interval = decay_batch_iteration / self.decay_step_size
- lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
- lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
- return lrs
- def get_lr(self):
- """Calculates the learning rate at batch index. This function treats
- `self.last_batch_iteration` as the last batch index.
- """
- if self.last_batch_iteration < self.total_size:
- return self._get_cycle_lr()
- return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
- def get_mom(self):
- """Calculates the momentum at batch index. This function treats
- `self.last_batch_iteration` as the last batch index.
- """
- if not self.cycle_momentum:
- return None
- if self.last_batch_iteration < self.total_size:
- return self._get_cycle_mom()
- return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
- return self._last_lr
- def step(self, batch_iteration=None):
- """ Updates the optimizer with the learning rate for the last batch index.
- `self.last_batch_iteration` is treated as the last batch index.
- If self.cycle_momentum is true, also updates optimizer momentum.
- """
- if batch_iteration is None:
- batch_iteration = self.last_batch_iteration + 1
- self.last_batch_iteration = batch_iteration
- for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
- param_group['lr'] = lr
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
- if self.cycle_momentum:
- momentums = self.get_mom()
- for param_group, momentum in zip(self.optimizer.param_groups, momentums):
- param_group['betas'] = momentum
- def state_dict(self):
- return {'last_batch_iteration': self.last_batch_iteration}
- def load_state_dict(self, sd):
- self.last_batch_iteration = sd['last_batch_iteration']
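- # Worked sketch of the 1Cycle shape (illustrative numbers, `optimizer` assumed to exist):
- # with cycle_min_lr=1e-4, cycle_max_lr=1e-3 and cycle_first_step_size=cycle_second_step_size=1000,
- # the LR ramps linearly from 1e-4 to 1e-3 over the first 1000 steps, back down to 1e-4 over the
- # next 1000 steps, and then follows _get_decay_lr() (dividing cycle_min_lr by
- # 1 + decay_lr_rate * interval) for the remainder of training; momentum, if cycled, moves in the
- # opposite direction between cycle_max_mom and cycle_min_mom.
- #
- #   scheduler = OneCycle(optimizer, cycle_min_lr=1e-4, cycle_max_lr=1e-3,
- #                        cycle_first_step_size=1000, decay_lr_rate=0.001,
- #                        decay_step_size=1000)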
- class WarmupLR(object):
- """Increase the learning rate of each parameter group from min lr to max lr
- over warmup_num_steps steps, and then fix at max lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- warmup_min_lr (float or list): minimum learning rate. Default: 0
- warmup_max_lr (float or list): maximum learning rate. Default: 0.001
- warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
- warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
- last_batch_iteration (int): The index of the last batch. Default: -1.
- Example:
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = WarmupLR(optimizer)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- """
- def __init__(self,
- optimizer: Optimizer,
- warmup_min_lr: float = 0.0,
- warmup_max_lr: float = 0.001,
- warmup_num_steps: int = 1000,
- warmup_type: str = WARMUP_LOG_RATE,
- last_batch_iteration: int = -1):
- self.optimizer = get_torch_optimizer(optimizer)
- self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
- self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
- self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
- self.warmup_num_steps = max(2, warmup_num_steps)
- # Currently only support linear and log function
- if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
- logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
- f"is set to default (log)")
- warmup_type = WARMUP_LOG_RATE
- self.warmup_type = warmup_type
- self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
- self.last_batch_iteration = last_batch_iteration
- def get_lr(self):
- if self.last_batch_iteration < 0:
- logger.warning("Attempting to get learning rate from scheduler before it has started")
- return [0.0]
- gamma = self._get_gamma()
- return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
- return self._last_lr
- def step(self, last_batch_iteration=None):
- if last_batch_iteration is None:
- last_batch_iteration = self.last_batch_iteration + 1
- self.last_batch_iteration = last_batch_iteration
- for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
- param_group['lr'] = lr
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
- def state_dict(self):
- return {'last_batch_iteration': self.last_batch_iteration}
- def load_state_dict(self, sd):
- self.last_batch_iteration = sd['last_batch_iteration']
- def _get_gamma(self):
- if self.last_batch_iteration < self.warmup_num_steps:
- if self.warmup_type == WARMUP_LOG_RATE:
- return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
- elif self.warmup_type == WARMUP_LINEAR_RATE:
- return self.last_batch_iteration / self.warmup_num_steps
- return 1.0
- def _format_param(self, optimizer, param_value, param_name):
- if isinstance(param_value, list) or isinstance(param_value, tuple):
- if len(param_value) != len(optimizer.param_groups):
- raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
- FileNotFoundError(param_value)))
- return list(param_value)
- return [param_value] * len(optimizer.param_groups)
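- # Worked sketch of the two warmup curves (illustrative numbers, `optimizer` assumed to exist):
- # with warmup_min_lr=0, warmup_max_lr=0.001 and warmup_num_steps=1000, the linear form gives
- # lr = 0.001 * step / 1000, while the default log form gives lr = 0.001 * ln(step + 1) / ln(1000),
- # which rises much faster early on (about 0.67e-3 already at step 100); after 1000 steps both
- # hold at 0.001.
- #
- #   scheduler = WarmupLR(optimizer, warmup_max_lr=0.001,
- #                        warmup_num_steps=1000, warmup_type='linear')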
- class WarmupDecayLR(WarmupLR):
- """Increase the learning rate of each parameter group from min lr to max lr
- over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- total_num_steps (int): total number of training steps
- warmup_min_lr (float or list): minimum learning rate. Default: 0
- warmup_max_lr (float or list): maximum learning rate. Default: 0.001
- warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
- warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
- last_batch_iteration (int): The index of the last batch. Default: -1.
- Example:
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = WarmupDecayLR(optimizer, 1000000)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- """
- def __init__(self,
- optimizer: Optimizer,
- total_num_steps: int,
- warmup_min_lr: float = 0.0,
- warmup_max_lr: float = 0.001,
- warmup_num_steps: int = 1000,
- warmup_type: str = WARMUP_LOG_RATE,
- last_batch_iteration: int = -1):
- self.total_num_steps = total_num_steps
- super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
- last_batch_iteration)
- if self.total_num_steps < self.warmup_num_steps:
- logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
- total_num_steps, warmup_num_steps))
- def _get_gamma(self):
- if self.last_batch_iteration < self.warmup_num_steps:
- if self.warmup_type == WARMUP_LOG_RATE:
- return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
- elif self.warmup_type == WARMUP_LINEAR_RATE:
- return self.last_batch_iteration / self.warmup_num_steps
- return max(
- 0.0,
- float(self.total_num_steps - self.last_batch_iteration) /
- float(max(1.0, self.total_num_steps - self.warmup_num_steps)))
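- # Worked sketch of the post-warmup decay (illustrative numbers, `optimizer` assumed to exist):
- # with total_num_steps=10000 and warmup_num_steps=1000, the factor applied to warmup_max_lr after
- # warmup is (10000 - step) / (10000 - 1000), so the LR falls linearly from warmup_max_lr at
- # step 1000 to 0 at step 10000 (clamped at 0 beyond that).
- #
- #   scheduler = WarmupDecayLR(optimizer, total_num_steps=10000,
- #                             warmup_max_lr=0.001, warmup_num_steps=1000)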
- class WarmupCosineLR(object):
- """Increase the learning rate of each parameter group from min lr ratio to max lr ratio
- over warmup_num_steps steps, and then decay at cosine rate over the remaining training steps to min cosine ratio.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- total_num_steps (int): total number of training steps
- warmup_min_ratio (float or list): warmup start learning rate ratio. Default: 0
- warmup_num_steps (int): number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000
- warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log
- cos_min_ratio (float): cosine end learning rate ratio. Default: 0.0001
- last_batch_iteration (int): The index of the last batch. Default: -1.
- Example:
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = WarmupCosineLR(optimizer, 1000000)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- """
- def __init__(self,
- optimizer: Optimizer,
- total_num_steps: int,
- warmup_min_ratio: float = 0.0,
- warmup_num_steps: int = 1000,
- cos_min_ratio: float = 0.0001,
- warmup_type: str = WARMUP_LOG_RATE,
- last_batch_iteration: int = -1):
- self.optimizer = get_torch_optimizer(optimizer)
- self.total_num_steps = total_num_steps
- self.last_batch_iteration = last_batch_iteration
- self.cos_min_ratio = cos_min_ratio
- # Currently only support linear and log function
- if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
- logger.warning(f"Using unknown warmup_type: {warmup_type}. The increasing function "
- f"is set to default (log)")
- warmup_type = WARMUP_LOG_RATE
- self.warmup_type = warmup_type
- self.warmup_min_ratio = warmup_min_ratio
- self.warmup_num_steps = max(2, warmup_num_steps)
- self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
- if self.total_num_steps < self.warmup_num_steps:
- logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
- total_num_steps, warmup_num_steps))
- self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
- def get_lr_ratio(self):
- if self.last_batch_iteration < 0:
- logger.warning("Attempting to get learning rate from scheduler before it has started")
- return [0.0]
- if self.last_batch_iteration < self.warmup_num_steps:
- if self.warmup_type == WARMUP_LOG_RATE:
- ratio = self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
- elif self.warmup_type == WARMUP_LINEAR_RATE:
- ratio = self.last_batch_iteration / self.warmup_num_steps
- ratio_delta = 1. - self.warmup_min_ratio
- ratio = self.warmup_min_ratio + ratio * ratio_delta
- return ratio
- real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1
- real_total_steps = self.total_num_steps - self.warmup_num_steps
- ratio_delta = 1. - self.cos_min_ratio
- ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
- ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
- return ratio
- def step(self, last_batch_iteration=None):
- if last_batch_iteration is None:
- last_batch_iteration = self.last_batch_iteration + 1
- self.last_batch_iteration = last_batch_iteration
- lrs = self.get_lr()
- for param_group, lr in zip(self.optimizer.param_groups, lrs):
- param_group['lr'] = lr
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
- def get_lr(self):
- if self.last_batch_iteration < 0:
- logger.warning("Attempting to get learning rate from scheduler before it has started")
- return [0.0]
- lr_ratio = self.get_lr_ratio()
- return [org_lr * lr_ratio for org_lr in self.org_lrs]
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
- return self._last_lr
- def state_dict(self):
- return {'last_batch_iteration': self.last_batch_iteration}
- def load_state_dict(self, sd):
- self.last_batch_iteration = sd['last_batch_iteration']
- def _format_param(self, optimizer, param_value, param_name):
- if isinstance(param_value, list) or isinstance(param_value, tuple):
- if len(param_value) != len(optimizer.param_groups):
- raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
- FileNotFoundError(param_value)))
- return list(param_value)
- return [param_value] * len(optimizer.param_groups)
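- # Worked sketch of the cosine phase (illustrative numbers, `optimizer` assumed to exist):
- # with total_num_steps=10000, warmup_num_steps=1000 and cos_min_ratio=0.1, the LR scale after
- # warmup is 0.1 + 0.9 * (1 + cos(pi * t / 9000)) / 2, where t counts steps past the warmup;
- # the scale is ~1.0 right after warmup, 0.55 halfway through the decay and 0.1 at the end of
- # training, and it multiplies the LR each param group started training with (org_lrs).
- #
- #   scheduler = WarmupCosineLR(optimizer, total_num_steps=10000,
- #                              warmup_num_steps=1000, cos_min_ratio=0.1)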
|