# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE
from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, WARMUP_TYPE, WARMUP_LOG_RATE, WARMUP_LINEAR_RATE
from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS
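

# Assertion helpers shared by the schedule tests below.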
def _verify_continuous_decrease(values):
    for i in range(len(values) - 1):
        assert values[i] > values[i + 1]


def _verify_continuous_increase(values):
    for i in range(len(values) - 1):
        assert values[i] < values[i + 1]


def _verify_staircase_increase(values, step_size):
    # Within each staircase step of `step_size` entries, all values must be equal.
    num_values = len(values)
    for i in range(0, num_values, step_size):
        j = min(i + step_size, num_values)
        assert all([values[i] == v for v in values[i:j]])
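

# Check that get_lr() can be called before any optimizer step has run, for
# each supported scheduler type.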
- @pytest.mark.parametrize("scheduler_type,params", [(WARMUP_LR, {}),
- (WARMUP_DECAY_LR, {
- WARMUP_NUM_STEPS: 10,
- TOTAL_NUM_STEPS: 20
- }), (ONE_CYCLE, {
- CYCLE_MIN_LR: 0,
- CYCLE_MAX_LR: 0.1
- }), (LR_RANGE_TEST, {})])
- class TestGetLrBeforeTrain(DistributedTest):
- world_size = 1
- def test(self, scheduler_type, params):
- config_dict = {
- "train_batch_size": 2,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 0.00015
- },
- },
- "scheduler": {
- "type": scheduler_type,
- "params": params
- },
- "gradient_clipping": 1.0
- }
- hidden_dim = 10
- model = SimpleModel(hidden_dim, empty_grad=False)
- model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
- model=model,
- model_parameters=model.parameters())
- data_loader = random_dataloader(model=model,
- total_samples=50,
- hidden_dim=hidden_dim,
- device=model.device,
- dtype=torch.float)
- for n, batch in enumerate(data_loader):
- # get lr before training starts
- lr_scheduler.get_lr()
- loss = model(batch[0], batch[1])
- model.backward(loss)
- model.step()
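

# Warmup schedules: lr should climb from WARMUP_MIN_LR to WARMUP_MAX_LR over
# WARMUP_NUM_STEPS, then either hold there (WarmupLR) or decay (WarmupDecayLR).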
- @pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
- @pytest.mark.parametrize("warmup_type", [WARMUP_LOG_RATE, WARMUP_LINEAR_RATE])
- class TestLrSchedule(DistributedTest):
- world_size = 1
- def test_lr_warmup_schedule(self, warmup_num_steps, warmup_type):
- config_dict = {
- "train_batch_size": 2,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 0.00015
- },
- },
- "scheduler": {
- "type": WARMUP_LR,
- "params": {
- WARMUP_MIN_LR: 0.1,
- WARMUP_MAX_LR: 0.2,
- WARMUP_NUM_STEPS: warmup_num_steps,
- WARMUP_TYPE: warmup_type,
- }
- },
- "gradient_clipping": 1.0
- }
- schedule_params = config_dict["scheduler"]["params"]
- total_num_steps = 2 * warmup_num_steps
- hidden_dim = 10
- model = SimpleModel(hidden_dim, empty_grad=False)
- model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
- model=model,
- model_parameters=model.parameters())
- data_loader = random_dataloader(model=model,
- total_samples=total_num_steps * 2,
- hidden_dim=hidden_dim,
- device=model.device,
- dtype=torch.float)
- step_lrs = []
- for n, batch in enumerate(data_loader):
- loss = model(batch[0], batch[1])
- model.backward(loss)
- model.step()
- step_lrs.append(lr_scheduler.get_lr())
- # Verify initial lr
- assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]
- # Verify warmup completion
- warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
- warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
- assert step_lrs[warmup_num_steps] == warmup_max_lr
- # Verify post-warmup completion
- assert all([warmup_max_lr == lr for lr in step_lrs[warmup_num_steps:]])

    def test_lr_warmup_decay_schedule(self, warmup_num_steps, warmup_type):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": WARMUP_DECAY_LR,
                "params": {
                    WARMUP_MIN_LR: 0.1,
                    WARMUP_MAX_LR: 0.2,
                    WARMUP_NUM_STEPS: warmup_num_steps,
                    TOTAL_NUM_STEPS: warmup_num_steps * 2,
                    WARMUP_TYPE: warmup_type
                }
            },
            "gradient_clipping": 1.0
        }
        schedule_params = config_dict["scheduler"]["params"]
        total_num_steps = schedule_params[TOTAL_NUM_STEPS]

        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=total_num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify decay phase
        previous_lr = warmup_max_lr
        for lr in step_lrs[warmup_num_steps + 1:]:
            assert lr < previous_lr
            previous_lr = lr
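

# The lr reported by the scheduler must match the lr the engine applies to the
# optimizer at every step.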
- @pytest.mark.parametrize("scheduler_type,params", [(WARMUP_LR, {}),
- (WARMUP_DECAY_LR, {
- WARMUP_NUM_STEPS: 5,
- TOTAL_NUM_STEPS: 10
- }),
- (ONE_CYCLE, {
- CYCLE_MIN_LR: 0,
- CYCLE_MAX_LR: 0.1,
- CYCLE_FIRST_STEP_SIZE: 5,
- DECAY_STEP_SIZE: 5
- }),
- (LR_RANGE_TEST, {
- LR_RANGE_TEST_MIN_LR: 1e-4,
- LR_RANGE_TEST_STEP_SIZE: 1
- })])
- class TestSchedulerOptimizerParity(DistributedTest):
- world_size = 1
- def test(self, scheduler_type, params):
- config_dict = {
- "train_batch_size": 2,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 0.00015
- },
- },
- "scheduler": {
- "type": scheduler_type,
- "params": params
- },
- "gradient_clipping": 1.0
- }
- hidden_dim = 10
- model = SimpleModel(hidden_dim, empty_grad=False)
- model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
- model=model,
- model_parameters=model.parameters())
- data_loader = random_dataloader(model=model,
- total_samples=50,
- hidden_dim=hidden_dim,
- device=model.device,
- dtype=torch.float)
- for n, batch in enumerate(data_loader):
- loss = model(batch[0], batch[1])
- model.backward(loss)
- model.step()
- assert lr_scheduler.get_lr() == model.get_lr()
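

# LR range test: lr starts at LR_RANGE_TEST_MIN_LR and increases monotonically,
# either in staircase jumps every LR_RANGE_TEST_STEP_SIZE steps or continuously.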
- @pytest.mark.parametrize("min_lr, step_rate, step_size, staircase",
- [(1e-4, 1e-5, 1, True),
- (1e-5, 1e-5, 1, False),
- (1e-4, 1e-3, 10, True),
- (1e-3, 1e-3, 10, False),
- (1e-2, 1e-2, 19, True),
- (1e-2, 1e-2, 19, False)
- ])# yapf: disable
- class TestLrRange(DistributedTest):
- world_size = 1
- def test(self, min_lr, step_rate, step_size, staircase):
- config_dict = {
- "train_batch_size": 2,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 0.00015
- },
- },
- "scheduler": {
- "type": LR_RANGE_TEST,
- "params": {
- LR_RANGE_TEST_MIN_LR: min_lr,
- LR_RANGE_TEST_STEP_RATE: step_rate,
- LR_RANGE_TEST_STEP_SIZE: step_size,
- LR_RANGE_TEST_STAIRCASE: staircase
- }
- },
- "gradient_clipping": 1.0
- }
- hidden_dim = 10
- model = SimpleModel(hidden_dim, empty_grad=False)
- model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
- model=model,
- model_parameters=model.parameters())
- data_loader = random_dataloader(model=model,
- total_samples=max(50, step_size * 2),
- hidden_dim=hidden_dim,
- device=model.device,
- dtype=torch.float)
- step_lrs = []
- for _, batch in enumerate(data_loader):
- step_lrs.extend(lr_scheduler.get_lr())
- loss = model(batch[0], batch[1])
- model.backward(loss)
- model.step()
- # Verify starting lr
- assert step_lrs[0] == min_lr
- if staircase:
- # Verify staircase increasing lr
- _verify_staircase_increase(step_lrs, step_size)
- else:
- # Verify continuous increasing lr
- _verify_continuous_increase(step_lrs)
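

# OneCycle: lr rises from CYCLE_MIN_LR to CYCLE_MAX_LR over the first phase,
# falls back over the second, then optionally decays; momentum cycles in the
# opposite direction.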
class TestOneCycle(DistributedTest):
    world_size = 1

    @pytest.mark.parametrize("min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size",
                             [
                                 (1e-5, 1e-2, 1e-3, 10, 10),
                                 (1e-3, 1e-1, 0, 21, 21),
                                 (1e-5, 1e-2, 1e-3, 10, 10),
                                 (1e-3, 1e-1, 1e-1, 21, 21),
                                 (1e-5, 1e-1, 0, 10, 0),
                             ])  # yapf: disable
    def test_lr(self, min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": ONE_CYCLE,
                "params": {
                    CYCLE_MIN_LR: min_lr,
                    CYCLE_MAX_LR: max_lr,
                    DECAY_LR_RATE: decay_rate,
                    CYCLE_FIRST_STEP_SIZE: cycle_step_size,
                    DECAY_STEP_SIZE: decay_step_size
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, cycle_step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        step_lrs = []
        for _, batch in enumerate(data_loader):
            step_lrs.extend(lr_scheduler.get_lr())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting lr
        assert step_lrs[0] == min_lr

        # Verify peak lr
        assert step_lrs[cycle_step_size] == max_lr

        # Verify increasing phase
        _verify_continuous_increase(step_lrs[:cycle_step_size])

        # Verify decreasing phase
        _verify_continuous_decrease(step_lrs[cycle_step_size:(cycle_step_size * 2)])

        # Verify decay phase
        if decay_rate > 0:
            _verify_continuous_decrease(step_lrs[(cycle_step_size * 2):])
- @pytest.mark.parametrize("min_mom, max_mom, decay_rate, step_size",
- [
- (0.08, 0.09, 1e-3, 10),
- (0.08, 0.09, 0, 21),
- (0.08, 0.09, 1e-3, 10),
- (0.08, 0.09, 0, 21),
- ]) # yapf: disable
- def test_mom(self, min_mom, max_mom, decay_rate, step_size):
- config_dict = {
- "train_batch_size": 2,
- "steps_per_print": 1,
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 0.00015
- },
- },
- "scheduler": {
- "type": ONE_CYCLE,
- "params": {
- CYCLE_MIN_LR: 1e-3,
- CYCLE_MAX_LR: 1e-2,
- CYCLE_MIN_MOM: min_mom,
- CYCLE_MAX_MOM: max_mom,
- DECAY_MOM_RATE: decay_rate,
- CYCLE_FIRST_STEP_SIZE: step_size,
- DECAY_STEP_SIZE: step_size
- }
- },
- "gradient_clipping": 1.0
- }
- hidden_dim = 10
- model = SimpleModel(hidden_dim, empty_grad=False)
- model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
- model=model,
- model_parameters=model.parameters())
- data_loader = random_dataloader(model=model,
- total_samples=max(50, step_size * 3),
- hidden_dim=hidden_dim,
- device=model.device,
- dtype=torch.float)
- step_moms = []
- for _, batch in enumerate(data_loader):
- step_moms.append(lr_scheduler.get_mom())
- loss = model(batch[0], batch[1])
- model.backward(loss)
- model.step()
- # Verify starting lr
- assert step_moms[0][0][0] == max_mom
- # Verify peak lr
- assert step_moms[step_size][0][0] == min_mom
- # Verify decreasing phase
- _verify_continuous_decrease(step_moms[:step_size])
- # Verify increasing phase
- _verify_continuous_increase(step_moms[step_size:(step_size * 2)])
- # Verify decay phase
- if decay_rate > 0:
- _verify_continuous_increase(step_moms[(step_size * 2):])