# compress.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import re
import os
import json

from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible
from .config import get_compression_config
from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
from .constants import *

try:
    import neural_compressor as nc
except ImportError:
    nc = None


def check_deepspeed_config(config):
    if isinstance(config, dict):
        return config
    elif os.path.exists(config):
        return json.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
    else:
        raise ValueError(
            f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}")


def get_module_name(group_name, model, key_word, exist_module_name, mpu=None, verbose=True):
    '''
    Get the associated module names from the model based on the key_word provided by the user.
    '''
    return_module_name = []
    for name, module in model.named_modules():
        module_check = is_module_compressible(module, mpu)
        if re.search(key_word, name) is not None and module_check:
            if name in exist_module_name and verbose:
                # logger.warning
                raise ValueError(
                    f"{name} is already added to compression, please check your config file for {group_name}.")
            if name not in exist_module_name:
                exist_module_name.add(name)
                return_module_name.append(name)
    return return_module_name, exist_module_name
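
# A minimal sketch of how get_module_name may be used, assuming a hypothetical `model` object
# and an illustrative regex scope; key_word is matched with re.search against the names from
# model.named_modules(), so a scope like r"attention\.self" would match module names such as
# "bert.encoder.layer.0.attention.self.query":
#
#   matched, seen = get_module_name("wq1", model, r"attention\.self", set(), mpu=None)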


def get_compress_methods(model, compress_methods, mpu=None):
    # extract the compression modules for each method in compress_methods
    layer_added_compress_methods = []
    for method, method_content in compress_methods.items():
        if LAYER_REDUCTION in method:
            continue
        # loop over the different methods, i.e., weight quantization, activation quantization, etc.
        exist_module_name = set()
        shared_parameters = method_content[SHARED_PARAMETERS]  # get all the shared parameters
        for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
            # loop over the different groups, i.e., weight quantization group 1, weight quantization group 2, etc.
            module_name_list = []
            related_module_name_list = []
            if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
                # this is used for head/row/channel pruning: if users provide the related module scope,
                # we can shrink the layer dim for them; otherwise we just mask those weights as zeros
                for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE],
                                                       method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     mpu=mpu)
                    module_name_list.append(module_name)
                    tmp_related_module_name_list = []
                    for rkw in related_key_words:
                        # the related key words can be a list, for instance the QKV scopes for the O matrix in attention
                        module_name, _ = get_module_name(group_name, model, rkw, set(), mpu=mpu)
                        tmp_related_module_name_list.append(module_name)
                    related_module_name_list.append(tmp_related_module_name_list)
            else:
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     mpu=mpu)
                    module_name_list.append(module_name)

            if module_name_list:
                # combine the shared parameters with each group's own parameters
                combined_method_parameters = {
                    **(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
                    **shared_parameters
                }
                compression_item = [module_name_list, related_module_name_list, {method: combined_method_parameters}]
                layer_added_compress_methods.append(compression_item)
    return layer_added_compress_methods
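
# Sketch of the return value: every compression_item collected above has the shape
# [module_name_list, related_module_name_list, {method: combined_parameters}]. The concrete
# module names and parameter values below are only illustrative assumptions:
#
#   [[['bert.encoder.layer.0.attention.self.query']],  # modules matched by the group scope
#    [],                                                # related modules (empty unless dims are shrunk)
#    {'weight_quantization': {'start_bits': 12, 'target_bits': 8}}]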


def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
    """
    Compress a model: replace linear/conv2d layers with deepspeed compression-aware modules
    Args:
        model (`torch.nn.Module`)
            The model to compress.
        deepspeed_config (`DeepSpeedConfig`)
            The path of the DeepSpeed config file (or the config dict itself)
        teacher_model (`torch.nn.Module`, optional)
            The teacher model; required when layer reduction is enabled
        mpu
            The mpu module for Row/Column parallelism
    """
    compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
    if hasattr(model, 'module'):
        c_model = model.module
    else:
        c_model = model

    # For layer reduction
    if compress_methods[LAYER_REDUCTION][LAYER_REDUCTION_ENABLED]:
        assert teacher_model is not None, "Teacher model is required for layer reduction"
        student_initialization(c_model, teacher_model, deepspeed_config)

    layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
    compression_preparation(c_model, layer_added_compress_methods, mpu)

    # For the sparse pruning snip_momentum method
    shared_parameters = compress_methods[SPARSE_PRUNING][SHARED_PARAMETERS]
    if shared_parameters[SPARSE_PRUNING_ENABLED] and \
            shared_parameters[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
        assert nc is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"

        from .helper import generate_pruners, register_on_step_begin
        from nc import WeightPruningConfig

        config = WeightPruningConfig(target_sparsity=1 - shared_parameters[SPARSE_PRUNING_DENSE_RATIO],
                                     pattern=shared_parameters[SPARSE_PRUNING_BLOCK_PATTERN],
                                     pruning_frequency=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE],
                                     start_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET],
                                     end_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_END],
                                     excluded_op_names=shared_parameters[SPARSE_PRUNING_EXCLUDED_MODULES])
        pruners = generate_pruners(config, c_model)
        c_model.pruners = pruners
        register_on_step_begin(c_model)

    return model
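
# A minimal usage sketch, assuming a config file at "ds_config.json" and caller-defined
# `model`/`teacher` objects (all three are assumptions for the example):
#
#   from deepspeed.compression.compress import init_compression
#   model = init_compression(model, "ds_config.json", teacher_model=teacher, mpu=None)
#   # ... run compression-aware training / fine-tuning on the returned model ...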


def redundancy_clean(model, deepspeed_config, mpu=None):
    """
    Remove the redundancy of a model
    Args:
        model (`torch.nn.Module`)
            The model to compress.
        deepspeed_config (`DeepSpeedConfig`)
            The path of the DeepSpeed config file (or the config dict itself)
        mpu
            The mpu module for Row/Column parallelism
    """
    compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
    if hasattr(model, 'module'):
        c_model = model.module
    else:
        c_model = model

    layer_added_compress_methods_tmp = get_compress_methods(c_model, compress_methods, mpu=mpu)
    # sort the methods so the compression techniques are applied in a fixed order
    order_list = [
        WEIGHT_QUANTIZATION, SPARSE_PRUNING, ROW_PRUNING, HEAD_PRUNING, CHANNEL_PRUNING, ACTIVATION_QUANTIZATION
    ]
    layer_added_compress_methods = sorted(layer_added_compress_methods_tmp,
                                          key=lambda x: order_list.index(list(x[2].keys())[0]))

    for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
        stored_mask = []
        need_mask = True if related_module_name_lists else False
        for i, mnl in enumerate(module_name_lists):
            for module_name in mnl:
                mask = fix_compression(c_model, module_name, compression_technique, dim_reduction=need_mask)
                if need_mask:
                    stored_mask.append(mask)
            if need_mask:
                for rmnl in related_module_name_lists[i]:
                    for j, module_name in enumerate(rmnl):
                        mask = fix_compression(c_model,
                                               module_name,
                                               compression_technique,
                                               mask=stored_mask[j],
                                               dim_reduction=True)
    return model
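
# A minimal usage sketch for finalizing compression after training, assuming the same
# "ds_config.json" used for init_compression (the file name is an assumption for the example):
#
#   from deepspeed.compression.compress import redundancy_clean
#   model = redundancy_clean(model, "ds_config.json", mpu=None)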


def student_initialization(student_model, teacher_model, deepspeed_config):
    '''
    Given a student model and a teacher model, initialize the selected student layers/modules from the teacher's weights.
    Args:
        student_model (`torch.nn.Module`)
            The model whose weights will be updated
        teacher_model (`torch.nn.Module`)
            The model that guides the student's learning
        deepspeed_config (`DeepSpeedConfig`)
            The path of the DeepSpeed config file (or the config dict itself)
    '''
    config = get_compression_config(check_deepspeed_config(deepspeed_config))
    compress_methods = config[LAYER_REDUCTION]

    module_name_prefix = compress_methods[MODULE_NAME_PREFIX]
    teacher_layer = compress_methods[TEACHER_LAYER]
    student_layer = [i for i in range(len(teacher_layer))]
    other_module_name = compress_methods[OTHER_MODULE_NAME]
    '''
    module_name_prefix (`str`)
        The prefix name before the layer #.
        Example 1: bert.encoder.layer, the prefix name for a BERT-base model
        Example 2: transformer.h, the prefix name for Hugging Face GPT-2
    teacher_layer (`list of integers`)
        The teacher layers used for the student's reinitialization
        Example 1: [1,3,5,7,9] means we want to match the 2nd/4th/6th/8th/10th layers of the teacher to the first 5 layers of the student
    student_layer (`list` or None)
        The student layers that need to be re-initialized
        Example 1: None means we want to reinitialize all the layers
        Example 2: [0,1,2,3,4] means we want to reinitialize the first 5 layers
    other_module_name (`list of string`)
        The other modules used for the student's reinitialization
        Example 1: ['bert.pooler', 'bert.embeddings', 'classifier'] means we want to copy the teacher's embedding/pooler/classifier weights to the student
        Example 2: ['transformer.w', 'transformer.ln_f', 'lm_head'] means we want to copy the teacher's embedding-related modules to the student
    Note that teacher_layer should match student_layer
    '''
    assert len(student_layer) == len(teacher_layer)
    for s_name, t_name in zip(student_layer, teacher_layer):
        s_module = recursive_getattr(student_model, module_name_prefix + '.' + str(s_name))
        t_module = recursive_getattr(teacher_model, module_name_prefix + '.' + str(t_name))
        for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
            s_param.data.copy_(t_param.data)
    for name in other_module_name:
        s_module = recursive_getattr(student_model, name)
        t_module = recursive_getattr(teacher_model, name)
        print(name)
        for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
            s_param.data.copy_(t_param.data)
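
# A sketch of a layer_reduction configuration matching the docstring examples above; the exact
# key strings are assumptions based on the constant names imported from .constants, and the
# values are illustrative:
#
#   "compression_training": {
#     "layer_reduction": {
#       "enabled": true,
#       "module_name_prefix": "bert.encoder.layer",
#       "teacher_layer": [1, 3, 5, 7, 9],
#       "other_module_name": ["bert.pooler", "bert.embeddings", "classifier"]
#     }
#   }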