scheduler.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .compress import get_module_name
from .constants import *
from .helper import recursive_getattr
from deepspeed.utils import logger


class compression_scheduler():
    '''
    Schedules the different compression techniques (quantization and pruning),
    enabling each one at its configured training step.
    '''

    def __init__(self, model, compression_config):
        self.model = model
        self.compression_config = compression_config
        self.make_init()
        self.training_steps = 0
        self.weight_quantization_enabled = False
        # One-time logging guards: a flag flips to True once the
        # corresponding technique has been announced.
        self.verbose = {
            WEIGHT_QUANTIZATION: False,
            ACTIVATION_QUANTIZATION: False,
            SPARSE_PRUNING: False,
            HEAD_PRUNING: False,
            ROW_PRUNING: False,
            CHANNEL_PRUNING: False
        }

    def make_init(self):
        # Build a per-technique lookup from the compression config: whether the
        # technique is enabled, its shared (schedule) parameters, and the list
        # of [group_name, matched module names, group params] triples.
        self.different_compression_methods = {}
        for method, method_content in self.compression_config.items():
            if LAYER_REDUCTION in method:
                continue
            self.different_compression_methods[method] = {
                TECHNIQUE_ENABLED: False,
                SHARED_PARAMETERS: None,
                DIFFERENT_GROUPS: []
            }
            exist_module_name = set()
            shared_parameters = method_content[SHARED_PARAMETERS]
            self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
            self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters

            for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
                module_name_list = []
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     self.model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     verbose=False)
                    module_name_list.extend(module_name)
                if module_name_list:
                    # dict.pop returns the 'params' value; copy first so the
                    # original config entry is not mutated
                    self.different_compression_methods[method][DIFFERENT_GROUPS].append(
                        [group_name, module_name_list,
                         method_parameters.copy().pop('params')])
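
    # Illustrative shape of what make_init builds (a sketch only; the group and
    # module names below are hypothetical, not taken from any real config):
    #
    #   self.different_compression_methods = {
    #       WEIGHT_QUANTIZATION: {
    #           TECHNIQUE_ENABLED: True,
    #           SHARED_PARAMETERS: {...},  # schedule offsets, etc.
    #           DIFFERENT_GROUPS: [
    #               ['wq1', ['encoder.layers.0.attention.query', ...], {...}],
    #           ],
    #       },
    #       ...
    #   }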

    def check_weight_quantization(self):
        # check weight quantization
        wq = self.different_compression_methods[WEIGHT_QUANTIZATION]
        if not wq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = wq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.weight_quantization_enabled = True

                if not self.verbose[WEIGHT_QUANTIZATION]:
                    logger.info(f'Weight quantization is enabled at step {self.training_steps}')
                    self.weight_quantization_enabled = True
                    self.verbose[WEIGHT_QUANTIZATION] = True

    def check_activation_quantization(self):
        # check activation quantization
        aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
        if not aq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = aq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.activation_quantization_enabled = True

                if not self.verbose[ACTIVATION_QUANTIZATION]:
                    logger.info(f'Activation quantization is enabled at step {self.training_steps}')
                    self.verbose[ACTIVATION_QUANTIZATION] = True

    def check_sparse_pruning(self):
        # check sparse pruning; unlike the other techniques, this one has both
        # a start and an end offset, so it is only switched on while
        # training_steps lies inside that window
        sp = self.different_compression_methods[SPARSE_PRUNING]
        if not sp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = sp[SHARED_PARAMETERS]
            if shared_parameters[TECHNIQUE_SCHEDULE_OFFSET] <= self.training_steps <= shared_parameters[
                    TECHNIQUE_SCHEDULE_OFFSET_END]:
                for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.sparse_pruning_enabled = True

                if not self.verbose[SPARSE_PRUNING]:
                    logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
                    self.verbose[SPARSE_PRUNING] = True

    def check_head_pruning(self):
        # check head pruning
        hp = self.different_compression_methods[HEAD_PRUNING]
        if not hp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = hp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.head_pruning_enabled = True

                if not self.verbose[HEAD_PRUNING]:
                    logger.info(f'Head pruning is enabled at step {self.training_steps}')
                    self.verbose[HEAD_PRUNING] = True

    def check_row_pruning(self):
        # check row pruning
        rp = self.different_compression_methods[ROW_PRUNING]
        if not rp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = rp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.row_pruning_enabled = True

                if not self.verbose[ROW_PRUNING]:
                    logger.info(f'Row pruning is enabled at step {self.training_steps}')
                    self.verbose[ROW_PRUNING] = True

    def check_channel_pruning(self):
        # check channel pruning
        cp = self.different_compression_methods[CHANNEL_PRUNING]
        if not cp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = cp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.channel_pruning_enabled = True

                if not self.verbose[CHANNEL_PRUNING]:
                    logger.info(f'Channel pruning is enabled at step {self.training_steps}')
                    self.verbose[CHANNEL_PRUNING] = True

    def check_all_modules(self):
        # check all different compression methods we have
        self.check_weight_quantization()
        self.check_activation_quantization()
        self.check_sparse_pruning()
        self.check_head_pruning()
        self.check_row_pruning()
        self.check_channel_pruning()

    def step(self, step_zero_check=False):
        # step_zero_check=True runs the checks without advancing the counter,
        # so techniques scheduled at offset 0 take effect before the first
        # optimizer step
        if not step_zero_check:
            self.training_steps += 1
        self.check_all_modules()
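

# A minimal usage sketch (illustrative only: `model`, `compression_config`,
# `dataloader`, and `optimizer` are hypothetical stand-ins, and in practice
# the DeepSpeed engine typically constructs and steps this scheduler itself):
#
#     scheduler = compression_scheduler(model, compression_config)
#     scheduler.step(step_zero_check=True)  # enable any step-0 techniques
#     for batch in dataloader:
#         loss = model(batch)
#         loss.backward()
#         optimizer.step()
#         scheduler.step()  # advance the counter and enable due techniques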