# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math

from deepspeed.utils import logger
# from deepspeed.runtime.lr_schedules import WarmupLR
from ..constants import *

##### based on the paper random-ltd: https://arxiv.org/abs/2211.11586
class BaseScheduler(object):
    """Base class for random-LTD value schedules.

    Maps a global step count to a scheduled value (e.g. an effective
    sequence length) using the schedule configuration held in ``self.state``.
    """

    def __init__(self):
        # Mutable schedule state; the concrete RANDOM_LTD_* keys are
        # populated by subclasses (see RandomLTDScheduler.reset_to_init).
        self.state = {}

    def __fixed_root_get_value(self, global_steps, root_degree=None):
        """Root-curve schedule: ramp from the configured minimum toward the
        maximum as ``(step / required_steps) ** (1 / root_degree)``.

        Args:
            global_steps: current global training step.
            root_degree: root of the growth curve; read from the schedule
                config when not given explicitly.

        Returns:
            The scheduled value, snapped down to a multiple of the configured
            increase step and clamped to the configured maximum.
        """
        schedule_cfg = self.state[RANDOM_LTD_SCHEDULE_CONFIG]
        if root_degree is None:
            root_degree = schedule_cfg['root_degree']
        progress = (float(global_steps) / schedule_cfg[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree)
        value_span = self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]
        scheduled = math.floor(progress * value_span + self.state[RANDOM_LTD_MIN_VALUE])
        # Quantize down to the configured step granularity.
        scheduled -= scheduled % schedule_cfg[RANDOM_LTD_INCREASE_STEP]
        # Never exceed the configured maximum.
        return min(scheduled, self.state[RANDOM_LTD_MAX_VALUE])

    def get_value(self, global_steps):
        """Return the scheduled value for ``global_steps``.

        Raises:
            RuntimeError: if the configured schedule type is not supported.
        """
        if self.state[RANDOM_LTD_SCHEDULER_TYPE] == 'fixed_linear':
            # 'fixed_linear' is the root curve with degree 1.
            return self.__fixed_root_get_value(global_steps, 1)
        raise RuntimeError('Unsupported random LTD schedule type')
class RandomLTDScheduler(BaseScheduler):
    """Scheduler for random-LTD (https://arxiv.org/abs/2211.11586).

    Tracks the current effective sequence length for the layers that apply
    random layer-token dropping, and accumulates the total number of
    layer-tokens consumed across training steps.
    """

    def __init__(self, config):
        super().__init__()
        self.model_layer_num = config[RANDOM_LTD_TOTAL_LAYER_NUM]
        self.random_ltd_layer_num = config[RANDOM_LTD_LAYER_NUM]
        self.config_schedule = config[RANDOM_LTD_SCHEDULER]
        self.global_batch_size = config[RANDOM_LTD_GLOBAL_BATCH_SIZE]
        self.reset_to_init()

        # Layer/token-wise LR scheduling is not implemented yet.
        if config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
            logger.warning("**********Work In Progress************")
            raise NotImplementedError

        self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
        # self.first_step = True

    def get_total_layer_tokens(self, train_iters):
        """Simulate ``train_iters`` steps and return the total layer-tokens.

        NOTE: advances the scheduler state as a side effect.
        """
        for step in range(train_iters):
            self.update_seq(step)
        return self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS]

    def reset_to_init(self):
        """Re-seed schedule state from the config and zero the progress counters."""
        sched = self.config_schedule
        if sched is not None:
            self.state[RANDOM_LTD_MIN_VALUE] = sched[RANDOM_LTD_MIN_VALUE]
            self.state[RANDOM_LTD_MAX_VALUE] = sched[RANDOM_LTD_MAX_VALUE]
            # Start the schedule at its minimum value.
            self.state[RANDOM_LTD_CURRENT_VALUE] = sched[RANDOM_LTD_MIN_VALUE]
            self.state[RANDOM_LTD_SCHEDULE_CONFIG] = sched[RANDOM_LTD_SCHEDULE_CONFIG]
            self.state[RANDOM_LTD_SCHEDULER_TYPE] = sched[RANDOM_LTD_SCHEDULER_TYPE]
        self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
        self.state[RANDOM_LTD_CURR_STEP] = -1

    def get_current_seq(self):
        """Current scheduled sequence-length value."""
        return self.state[RANDOM_LTD_CURRENT_VALUE]

    def set_current_seq(self, seq_length):
        """Override the current scheduled sequence-length value."""
        self.state[RANDOM_LTD_CURRENT_VALUE] = seq_length

    def get_random_ltd_layer_num(self):
        """Number of layers that apply random LTD."""
        return self.random_ltd_layer_num

    def get_state(self):
        """Return the internal state dict (not a copy)."""
        return self.state

    def set_state(self, state):
        """Replace the internal state dict wholesale."""
        self.state = state

    def update_seq(self, global_steps):
        """Advance the schedule to ``global_steps`` and account consumed tokens."""
        if self.state[RANDOM_LTD_CURRENT_VALUE] < self.state[RANDOM_LTD_MAX_VALUE]:
            self.state[RANDOM_LTD_CURRENT_VALUE] = self.get_value(global_steps)
        # Only accumulate once per distinct step.
        if global_steps != self.state[RANDOM_LTD_CURR_STEP]:
            # LTD layers see the (reduced) current value; the remaining
            # layers see the full maximum sequence length.
            ltd_layer_tokens = self.state[RANDOM_LTD_CURRENT_VALUE] * self.random_ltd_layer_num
            dense_layer_tokens = self.state[RANDOM_LTD_MAX_VALUE] * (self.model_layer_num - self.random_ltd_layer_num)
            self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] += self.global_batch_size * (ltd_layer_tokens +
                                                                                      dense_layer_tokens)
            self.state[RANDOM_LTD_CURR_STEP] = global_steps

    def state_dict(self):
        """Checkpointable snapshot of the scheduler's progress."""
        checkpoint_keys = (
            RANDOM_LTD_CONSUMED_LAYER_TOKENS,
            RANDOM_LTD_CURR_STEP,
            RANDOM_LTD_CURRENT_VALUE,
            RANDOM_LTD_MIN_VALUE,
            RANDOM_LTD_MAX_VALUE,
        )
        return {key: self.state[key] for key in checkpoint_keys}

    def load_state_dict(self, state_dict):
        """Restore scheduler progress from a ``state_dict()`` snapshot."""
        for key in (
                RANDOM_LTD_CONSUMED_LAYER_TOKENS,
                RANDOM_LTD_CURR_STEP,
                RANDOM_LTD_CURRENT_VALUE,
                RANDOM_LTD_MIN_VALUE,
                RANDOM_LTD_MAX_VALUE,
        ):
            self.state[key] = state_dict[key]
|