# scheduler.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math

from deepspeed.utils import logger
# from deepspeed.runtime.lr_schedules import WarmupLR
from ..constants import *

# Based on the Random-LTD paper: https://arxiv.org/abs/2211.11586

class BaseScheduler(object):

    def __init__(self):
        self.state = {}

    def __fixed_root_get_value(self, global_steps, root_degree=None):
        s_state = self.state[RANDOM_LTD_SCHEDULE_CONFIG]
        if root_degree is None:
            root_degree = s_state['root_degree']
        # Fraction of the warmup completed, raised to 1/root_degree
        # (root_degree == 1 gives a linear ramp).
        next_seq = (float(global_steps) / s_state[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree)
        # Interpolate between the min and max values, then round down to the
        # nearest multiple of the configured increase step.
        next_seq = math.floor(next_seq * (self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]) +
                              self.state[RANDOM_LTD_MIN_VALUE])
        next_seq -= (next_seq % s_state[RANDOM_LTD_INCREASE_STEP])
        # Never exceed the configured maximum.
        next_seq = min(next_seq, self.state[RANDOM_LTD_MAX_VALUE])
        return next_seq
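
    # Worked example of the fixed-root schedule above (all numbers assumed
    # for illustration, not taken from any real config): with min_value=128,
    # max_value=512, require_steps=1000, increase_step=16 and root_degree=1,
    # at global_steps=250 the interpolation gives
    # floor(0.25 * (512 - 128) + 128) = 224, which is already a multiple of
    # 16 and below the maximum, so the scheduled value is 224.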

    def get_value(self, global_steps):
        if self.state[RANDOM_LTD_SCHEDULER_TYPE] == 'fixed_linear':
            # 'fixed_linear' is the fixed-root schedule with root degree 1.
            return self.__fixed_root_get_value(global_steps, 1)
        else:
            raise RuntimeError('Unsupported random LTD schedule type')


class RandomLTDScheduler(BaseScheduler):

    def __init__(self, config):
        super().__init__()
        self.model_layer_num = config[RANDOM_LTD_TOTAL_LAYER_NUM]
        self.random_ltd_layer_num = config[RANDOM_LTD_LAYER_NUM]
        self.config_schedule = config[RANDOM_LTD_SCHEDULER]
        self.global_batch_size = config[RANDOM_LTD_GLOBAL_BATCH_SIZE]
        self.reset_to_init()

        # Layer-token LR scheduling is not implemented yet.
        if config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
            logger.warning("**********Work In Progress************")
            raise NotImplementedError

        self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
        # self.first_step = True

    def get_total_layer_tokens(self, train_iters):
        # Simulate the full run to compute the total number of layer-tokens
        # that would be consumed over train_iters steps.
        for step in range(train_iters):
            self.update_seq(step)
        return self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS]

    def reset_to_init(self):
        if self.config_schedule is not None:
            self.state[RANDOM_LTD_MIN_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
            self.state[RANDOM_LTD_MAX_VALUE] = self.config_schedule[RANDOM_LTD_MAX_VALUE]
            self.state[RANDOM_LTD_CURRENT_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
            self.state[RANDOM_LTD_SCHEDULE_CONFIG] = self.config_schedule[RANDOM_LTD_SCHEDULE_CONFIG]
            self.state[RANDOM_LTD_SCHEDULER_TYPE] = self.config_schedule[RANDOM_LTD_SCHEDULER_TYPE]
        self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
        self.state[RANDOM_LTD_CURR_STEP] = -1

    def get_current_seq(self):
        return self.state[RANDOM_LTD_CURRENT_VALUE]

    def set_current_seq(self, seq_length):
        self.state[RANDOM_LTD_CURRENT_VALUE] = seq_length

    def get_random_ltd_layer_num(self):
        return self.random_ltd_layer_num

    def get_state(self):
        return self.state

    def set_state(self, state):
        self.state = state

    def update_seq(self, global_steps):
        if self.state[RANDOM_LTD_CURRENT_VALUE] < self.state[RANDOM_LTD_MAX_VALUE]:
            self.state[RANDOM_LTD_CURRENT_VALUE] = self.get_value(global_steps)
        if global_steps != self.state[RANDOM_LTD_CURR_STEP]:
            # Random-LTD layers process the (reduced) current sequence length,
            # while the remaining layers still process the full sequence.
            self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] += self.global_batch_size * (
                self.state[RANDOM_LTD_CURRENT_VALUE] * self.random_ltd_layer_num +
                self.state[RANDOM_LTD_MAX_VALUE] * (self.model_layer_num - self.random_ltd_layer_num))
            self.state[RANDOM_LTD_CURR_STEP] = global_steps
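
    # Worked example of the accounting above (numbers assumed for
    # illustration): with global_batch_size=16, model_layer_num=24,
    # random_ltd_layer_num=22, a current value of 224 and a max value of 512,
    # one step consumes 16 * (224 * 22 + 512 * (24 - 22)) = 16 * 5952
    # = 95232 layer-tokens.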

    def state_dict(self):
        return {
            RANDOM_LTD_CONSUMED_LAYER_TOKENS: self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS],
            RANDOM_LTD_CURR_STEP: self.state[RANDOM_LTD_CURR_STEP],
            RANDOM_LTD_CURRENT_VALUE: self.state[RANDOM_LTD_CURRENT_VALUE],
            RANDOM_LTD_MIN_VALUE: self.state[RANDOM_LTD_MIN_VALUE],
            RANDOM_LTD_MAX_VALUE: self.state[RANDOM_LTD_MAX_VALUE],
        }

    def load_state_dict(self, state_dict):
        self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = state_dict[RANDOM_LTD_CONSUMED_LAYER_TOKENS]
        self.state[RANDOM_LTD_CURR_STEP] = state_dict[RANDOM_LTD_CURR_STEP]
        self.state[RANDOM_LTD_CURRENT_VALUE] = state_dict[RANDOM_LTD_CURRENT_VALUE]
        self.state[RANDOM_LTD_MIN_VALUE] = state_dict[RANDOM_LTD_MIN_VALUE]
        self.state[RANDOM_LTD_MAX_VALUE] = state_dict[RANDOM_LTD_MAX_VALUE]
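
# A minimal usage sketch (kept commented out: the nested config shape follows
# what the code above reads, but every value is made up, and real configs are
# built by DeepSpeed's data-pipeline config parsing):
#
#     config = {
#         RANDOM_LTD_TOTAL_LAYER_NUM: 24,
#         RANDOM_LTD_LAYER_NUM: 22,
#         RANDOM_LTD_GLOBAL_BATCH_SIZE: 16,
#         RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE: {
#             RANDOM_LTD_LAYER_TOKEN_LR_ENABLED: False,
#         },
#         RANDOM_LTD_SCHEDULER: {
#             RANDOM_LTD_MIN_VALUE: 128,
#             RANDOM_LTD_MAX_VALUE: 512,
#             RANDOM_LTD_SCHEDULER_TYPE: 'fixed_linear',
#             RANDOM_LTD_SCHEDULE_CONFIG: {
#                 RANDOM_LTD_REQUIRE_STEP: 1000,
#                 RANDOM_LTD_INCREASE_STEP: 16,
#             },
#         },
#     }
#     scheduler = RandomLTDScheduler(config)
#     for step in range(10):
#         scheduler.update_seq(step)
#         print(step, scheduler.get_current_seq())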