# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

import torch

from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
from deepspeed.utils import logger

try:
    from neural_compressor.compression import pruner as nc_pruner
except ImportError as e:
    nc_pruner = None


def recursive_getattr(model, module_name):
    """
    Recursively get the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to get the attribute from.
        module_name (`str`)
            The name of the module to get the attribute from.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list:
        output = getattr(output, name)
    return output


def recursive_setattr(model, module_name, module):
    """
    Recursively set the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to set the attribute in.
        module_name (`str`)
            The name of the module to set the attribute in.
        module (`torch.nn.Module`)
            The module to set the attribute to.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list[:-1]:
        output = getattr(output, name)
    setattr(output, split_list[-1], module)
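

# Illustrative usage of the two helpers above (a sketch, not part of the library API): the
# dotted-name convention matches what `model.named_modules()` yields, so any nested submodule
# can be fetched and swapped in place by name.
#
#   import torch
#   model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#   fc = recursive_getattr(model, '0')                      # same module as model[0]
#   recursive_setattr(model, '0', torch.nn.Linear(4, 8))    # replace it in place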


def module_replacement(model, module_name, compression_technique=None, mpu=None):
    """
    Replace a module with a new module.
    Args:
        model (`torch.nn.Module`)
            The model to replace the module in.
        module_name (`str`)
            The name of the module to replace.
        compression_technique (`dict`, optional)
            The compression technique(s) to enable on the new module, keyed by technique name.
        mpu (optional)
            Model-parallel unit; needed to recognize and rebuild row/column-parallel linear layers.
    """
    # Get the old module
    old_module = recursive_getattr(model, module_name)

    need_bias = False
    if hasattr(old_module, 'bias') and old_module.bias is not None:
        need_bias = True

    # Initialize the new module
    if isinstance(old_module, LinearLayer_Compress) or isinstance(old_module, torch.nn.Linear):
        if isinstance(old_module, LinearLayer_Compress):
            new_module = old_module
        else:
            new_module = LinearLayer_Compress(old_module.in_features, old_module.out_features,
                                              bias=need_bias).to(device=old_module.weight.device,
                                                                 dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, Conv2dLayer_Compress) or isinstance(old_module, torch.nn.Conv2d):
        if isinstance(old_module, Conv2dLayer_Compress):
            new_module = old_module
        else:
            new_module = Conv2dLayer_Compress(old_module.in_channels, old_module.out_channels, old_module.kernel_size,
                                              old_module.stride, old_module.padding, old_module.dilation,
                                              old_module.groups, need_bias,
                                              old_module.padding_mode).to(device=old_module.weight.device,
                                                                          dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, torch.nn.BatchNorm2d):
        new_module = BNLayer_Compress(old_module.num_features, old_module.eps, old_module.momentum, old_module.affine,
                                      old_module.track_running_stats).to(old_module.weight.device,
                                                                         old_module.weight.dtype)
        new_module.weight.data = old_module.weight.data
        if need_bias:
            new_module.bias.data = old_module.bias.data
        new_module.running_mean.data = old_module.running_mean.data
        new_module.running_var.data = old_module.running_var.data
    elif isinstance(old_module, Embedding_Compress) or isinstance(old_module, torch.nn.Embedding):
        if isinstance(old_module, Embedding_Compress):
            new_module = old_module
        else:
            new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim,
                                            old_module.padding_idx, old_module.max_norm, old_module.norm_type,
                                            old_module.scale_grad_by_freq,
                                            old_module.sparse).to(device=old_module.weight.device,
                                                                  dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
    elif mpu is not None and (isinstance(old_module, ColumnParallelLinear_Compress)
                              or isinstance(old_module, mpu.ColumnParallelLinear)):
        if isinstance(old_module, ColumnParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = ColumnParallelLinear_Compress(mpu,
                                                       old_module.input_size,
                                                       old_module.output_size,
                                                       gather_output=old_module.gather_output,
                                                       skip_bias_add=old_module.skip_bias_add,
                                                       bias=need_bias).to(device=old_module.weight.device,
                                                                          dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif mpu is not None and (isinstance(old_module, RowParallelLinear_Compress)
                              or isinstance(old_module, mpu.RowParallelLinear)):
        if isinstance(old_module, RowParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = RowParallelLinear_Compress(mpu,
                                                    old_module.input_size,
                                                    old_module.output_size,
                                                    input_is_parallel=old_module.input_is_parallel,
                                                    skip_bias_add=old_module.skip_bias_add,
                                                    bias=need_bias).to(device=old_module.weight.device,
                                                                       dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    else:
        new_module = None

    if compression_technique is not None:
        for k, v in compression_technique.items():
            if k == SPARSE_PRUNING:
                if v[SPARSE_PRUNING_ENABLED]:
                    new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], v[SPARSE_PRUNING_METHOD])
            elif k == ROW_PRUNING:
                if v[ROW_PRUNING_ENABLED]:
                    new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], v[ROW_PRUNING_METHOD])
            elif k == HEAD_PRUNING:
                if v[HEAD_PRUNING_ENABLED]:
                    new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], v[HEAD_PRUNING_METHOD],
                                                   v[HEAD_PRUNING_NUM_HEADS])
            elif k == ACTIVATION_QUANTIZATION:
                if v[ACTIVATION_QUANTIZATION_ENABLED]:
                    new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS], v[ACTIVATION_QUANTIZE_TYPE],
                                                              v[ACTIVATION_QUANTIZE_RANGE])
            elif k == WEIGHT_QUANTIZATION:
                if v[WEIGHT_QUANTIZE_ENABLED]:
                    new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
                                                          v[WEIGHT_QUANTIZE_TARGET_BITS],
                                                          v[WEIGHT_QUANTIZATION_PERIOD],
                                                          v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
                                                          v[WEIGHT_QUANTIZE_TYPE], v[WEIGHT_QUANTIZE_GROUPS])
            elif k == CHANNEL_PRUNING:
                if v[CHANNEL_PRUNING_ENABLED]:
                    new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], v[CHANNEL_PRUNING_METHOD])
            else:
                raise NotImplementedError('Compression technique {} is not implemented'.format(k))

    # Replace the old module with the new one
    recursive_setattr(model, module_name, new_module)
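

# Example (illustrative sketch only): enabling sparse pruning on a single linear layer by name.
# The dict layout mirrors the keys consumed above; the module name, dense ratio, and method
# values are assumptions chosen purely for illustration.
#
#   technique = {
#       SPARSE_PRUNING: {
#           SPARSE_PRUNING_ENABLED: True,
#           SPARSE_PRUNING_DENSE_RATIO: 0.5,
#           SPARSE_PRUNING_METHOD: 'l1',
#       }
#   }
#   module_replacement(model, 'encoder.layer.0.attention.self.query', technique)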


def is_module_compressible(module, mpu=None):
    ret = isinstance(module, torch.nn.Linear) or \
          isinstance(module, torch.nn.Conv2d) or \
          isinstance(module, torch.nn.Embedding) or \
          isinstance(module, torch.nn.BatchNorm2d)

    if mpu is not None:
        ret = ret or isinstance(module, mpu.RowParallelLinear) or isinstance(module, mpu.ColumnParallelLinear)

    return ret


def compression_preparation(model, compression_technique_list, mpu):
    """
    Prepare the compression techniques of a model.
    Args:
        model (`torch.nn.Module`)
            The model to prepare the compression techniques for.
        compression_technique_list (`list`)
            The list of compression techniques to apply to the model. Each entry unpacks as
            (module_name_lists, _, compression_technique).
    """
    # First replace every compressible module with the corresponding compression wrapper
    for module_name, module in model.named_modules():
        if is_module_compressible(module, mpu):
            module_replacement(model, module_name, mpu=mpu)

    # Then enable the requested compression techniques on the explicitly named modules
    for module_name_lists, _, compression_technique in compression_technique_list:
        for mnl in module_name_lists:
            for module_name in mnl:
                module_replacement(model, module_name, compression_technique)

    return model
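

# Sketch of how compression_preparation can be driven (hypothetical names, for illustration
# only). Each list entry unpacks as (module_name_lists, _, compression_technique); the middle
# element is not used by this function.
#
#   technique_list = [
#       ([['encoder.layer.0.attention.self.query',
#          'encoder.layer.0.attention.self.key']], None, technique),
#   ]
#   model = compression_preparation(model, technique_list, mpu=None)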


def fix_compression(model, module_name, compression_technique, mask=None, dim_reduction=False):
    """
    Fix (finalize) the compression technique of a module.
    Args:
        model (`torch.nn.Module`)
            The model to fix the compression technique of.
        module_name (`str`)
            The name of the module to fix the compression technique of.
        compression_technique (`dict`)
            The compression technique(s) to fix for the module.
        mask (optional)
            A pruning mask to apply when fixing the pruning helpers.
        dim_reduction (`bool`)
            Whether to reduce the corresponding weight dimension while fixing pruning.
    """
    # Here we can make things much simpler by just replacing the module
    module = recursive_getattr(model, module_name)
    for k, v in compression_technique.items():
        if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[WEIGHT_QUANTIZE_ENABLED]:
            return module.fix_weight_quantization()
        elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
            return module.fix_sparse_pruning_helper()
        elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
            return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
        elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
            return module.fix_head_pruning_helper(mask, v[HEAD_PRUNING_NUM_HEADS], dim_reduction=dim_reduction)
        elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
            return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
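

# Illustrative call (a sketch reusing the `technique` dict from the earlier example): after
# training, this folds the temporary masks/quantizers into the module's parameters and returns
# whatever the matching fix_* helper returns (e.g. a pruning mask for the pruning branches).
#
#   mask = fix_compression(model, 'encoder.layer.0.attention.self.query', technique)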


def convert_conv1d_to_linear(model, convert_type):
    '''
    This is a helper function to convert Conv1D layers to torch.nn.Linear (e.g., for GPT-2 models from HF).
    '''
    if hasattr(model, 'module'):
        c_model = model.module
    else:
        c_model = model

    for name, module in c_model.named_modules():
        if isinstance(module, convert_type):
            old_module = recursive_getattr(c_model, name)
            new_module = torch.nn.Linear(old_module.weight.data.size(0),
                                         old_module.weight.data.size(1),
                                         bias=True if old_module.bias is not None else False)
            new_module.weight.data = old_module.weight.data.t().contiguous()
            if new_module.bias is not None:
                new_module.bias.data = old_module.bias.data.view(-1)

            recursive_setattr(c_model, name, new_module)

    return model
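

# Typical use (sketch, assuming Hugging Face transformers is installed): GPT-2 stores its
# attention/MLP projections as `Conv1D` modules whose weight is the transpose of an
# nn.Linear weight, which is why the conversion above transposes the weight matrix.
#
#   from transformers.pytorch_utils import Conv1D   # import path may differ across versions
#   model = convert_conv1d_to_linear(model, Conv1D)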


def generate_pruners(config, model):
    """Generate pruners.
    Args:
        config (`neural_compressor.WeightPruningConfig`)
            The WeightPruningConfig object describing the pruning setup.
        model (`torch.nn.Module`)
            The torch module object to be pruned.
    """
    assert nc_pruner is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"
    # `nc_pruner` above is a module alias, not an importable top-level package name,
    # so these helpers are imported through the full neural_compressor path.
    from neural_compressor.compression.pruner.utils import process_config, parse_to_prune
    from neural_compressor.compression.pruner.pruners import get_pruner
    assert isinstance(model, torch.nn.Module)
    pruners_info = process_config(config)
    pruners = []
    for info in pruners_info:
        modules = parse_to_prune(info, model)
        if modules == {}:
            logger.warning("one pruner hooks no layers, please have a check")

        pruners.append(get_pruner(info, modules))
        info['modules'] = [key for key in modules.keys()]
        info['len_of_modules'] = len(info['modules'])
        logger.info(info)
    return pruners
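

# Sketch of building pruners from a neural_compressor config (assumes neural_compressor is
# installed; the config values below are illustrative only):
#
#   from neural_compressor import WeightPruningConfig
#   config = WeightPruningConfig(target_sparsity=0.5, pattern='4x1', pruning_type='snip_momentum')
#   pruners = generate_pruners(config, model)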


def register_on_step_begin(model):
    """Mount on_step_begin to the model.
    Args:
        model (`torch.nn.Module`)
            The torch module object to be pruned.
    """

    def hook(module, input):
        for pruner in module.pruners:
            pruner.on_step_begin(0)

    hook_handle = model.register_forward_pre_hook(hook)
    return hook_handle


def rewrite_optimizer_step(opt: torch.optim.Optimizer):
    """Mount on_before/after_optimizer_step to the optimizer.
    Args:
        opt (`torch.optim.Optimizer`)
            The torch optimizer object to be hooked.
    """

    def new_step(self, closure=None):
        if hasattr(self, "pruners"):
            for pruner in self.pruners:
                pruner.on_before_optimizer_step()

        if closure is not None:
            res = self.orig_step(closure)
        else:
            res = self.orig_step()

        if hasattr(self, "pruners"):
            for pruner in self.pruners:
                pruner.on_after_optimizer_step()

        return res

    opt.orig_step = opt.step
    import types
    opt.step = types.MethodType(new_step, opt)
    return opt
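

# End-to-end sketch of wiring the pruner hooks into a training loop (illustrative only; the
# attribute name `pruners` is the one `hook` and `new_step` above look for):
#
#   pruners = generate_pruners(config, model)
#   model.pruners = pruners
#   opt.pruners = pruners
#   register_on_step_begin(model)    # fires pruner.on_step_begin before every forward pass
#   rewrite_optimizer_step(opt)      # wraps opt.step with on_before/after_optimizer_step
#
#   for batch in dataloader:
#       loss = model(**batch).loss
#       loss.backward()
#       opt.step()
#       opt.zero_grad()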