helper.py

'''Copyright The Microsoft DeepSpeed Team'''
import torch

from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *


def recursive_getattr(model, module_name):
    """
    Recursively get the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to get the attribute from.
        module_name (`str`)
            The name of the module to get the attribute from.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list:
        output = getattr(output, name)
    return output


def recursive_setattr(model, module_name, module):
    """
    Recursively set the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to set the attribute in.
        module_name (`str`)
            The name of the module to set the attribute in.
        module (`torch.nn.Module`)
            The module to set the attribute to.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list[:-1]:
        output = getattr(output, name)
    setattr(output, split_list[-1], module)
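
# A minimal usage sketch of the two helpers above; `net` and the
# 'encoder.fc1' path are hypothetical names, not part of this file:
#   layer = recursive_getattr(net, 'encoder.fc1')
#   recursive_setattr(net, 'encoder.fc1', torch.nn.Identity())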


def module_replacement(model, module_name, compression_technique=None, mpu=None):
    """
    Replace a module with a new (compression-enabled) module.
    Args:
        model (`torch.nn.Module`)
            The model to replace the module in.
        module_name (`str`)
            The name of the module to replace.
        compression_technique (`dict`)
            The compression technique(s) to enable on the new module.
        mpu
            An optional model-parallel utility object that provides the
            ColumnParallelLinear and RowParallelLinear classes.
    """
    # Get the old module
    old_module = recursive_getattr(model, module_name)

    need_bias = False
    if hasattr(old_module, 'bias') and old_module.bias is not None:
        need_bias = True

    # Initialize the new module
    if isinstance(old_module, (LinearLayer_Compress, torch.nn.Linear)):
        if isinstance(old_module, LinearLayer_Compress):
            new_module = old_module
        else:
            new_module = LinearLayer_Compress(old_module.in_features,
                                              old_module.out_features,
                                              bias=need_bias).to(device=old_module.weight.device,
                                                                 dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, (Conv2dLayer_Compress, torch.nn.Conv2d)):
        if isinstance(old_module, Conv2dLayer_Compress):
            new_module = old_module
        else:
            new_module = Conv2dLayer_Compress(old_module.in_channels,
                                              old_module.out_channels,
                                              old_module.kernel_size,
                                              old_module.stride,
                                              old_module.padding,
                                              old_module.dilation,
                                              old_module.groups,
                                              need_bias,
                                              old_module.padding_mode).to(device=old_module.weight.device,
                                                                          dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, torch.nn.BatchNorm2d):
        new_module = BNLayer_Compress(old_module.num_features,
                                      old_module.eps,
                                      old_module.momentum,
                                      old_module.affine,
                                      old_module.track_running_stats).to(old_module.weight.device,
                                                                         old_module.weight.dtype)
        new_module.weight.data = old_module.weight.data
        if need_bias:
            new_module.bias.data = old_module.bias.data
        new_module.running_mean.data = old_module.running_mean.data
        new_module.running_var.data = old_module.running_var.data
    elif isinstance(old_module, (Embedding_Compress, torch.nn.Embedding)):
        if isinstance(old_module, Embedding_Compress):
            new_module = old_module
        else:
            new_module = Embedding_Compress(old_module.num_embeddings,
                                            old_module.embedding_dim,
                                            old_module.padding_idx,
                                            old_module.max_norm,
                                            old_module.norm_type,
                                            old_module.scale_grad_by_freq,
                                            old_module.sparse).to(device=old_module.weight.device,
                                                                  dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
    elif mpu is not None and isinstance(old_module, (ColumnParallelLinear_Compress, mpu.ColumnParallelLinear)):
        if isinstance(old_module, ColumnParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = ColumnParallelLinear_Compress(
                mpu,
                old_module.input_size,
                old_module.output_size,
                gather_output=old_module.gather_output,
                skip_bias_add=old_module.skip_bias_add,
                bias=need_bias).to(device=old_module.weight.device,
                                   dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif mpu is not None and isinstance(old_module, (RowParallelLinear_Compress, mpu.RowParallelLinear)):
        if isinstance(old_module, RowParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = RowParallelLinear_Compress(
                mpu,
                old_module.input_size,
                old_module.output_size,
                input_is_parallel=old_module.input_is_parallel,
                skip_bias_add=old_module.skip_bias_add,
                bias=need_bias).to(device=old_module.weight.device,
                                   dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    else:
        # Unsupported module types get no compression wrapper
        new_module = None

    if compression_technique is not None:
        for k, v in compression_technique.items():
            if k == SPARSE_PRUNING:
                if v[SPARSE_PRUNING_ENABLED]:
                    new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
                                                     v[SPARSE_PRUNING_METHOD])
            elif k == ROW_PRUNING:
                if v[ROW_PRUNING_ENABLED]:
                    new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
                                                  v[ROW_PRUNING_METHOD])
            elif k == HEAD_PRUNING:
                if v[HEAD_PRUNING_ENABLED]:
                    new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
                                                   v[HEAD_PRUNING_METHOD],
                                                   v[HEAD_PRUNING_NUM_HEADS])
            elif k == ACTIVATION_QUANTIZATION:
                if v[ACTIVATION_QUANTIZATION_ENABLED]:
                    new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS],
                                                              v[ACTIVATION_QUANTIZE_TYPE],
                                                              v[ACTIVATION_QUANTIZE_RANGE])
            elif k == WEIGHT_QUANTIZATION:
                if v[WEIGHT_QUANTIZE_ENABLED]:
                    new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
                                                          v[WEIGHT_QUANTIZE_TARGET_BITS],
                                                          v[WEIGHT_QUANTIZATION_PERIOD],
                                                          v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
                                                          v[WEIGHT_QUANTIZE_TYPE],
                                                          v[WEIGHT_QUANTIZE_GROUPS])
            elif k == CHANNEL_PRUNING:
                if v[CHANNEL_PRUNING_ENABLED]:
                    new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
                                                      v[CHANNEL_PRUNING_METHOD])
            else:
                raise NotImplementedError(
                    'Compression technique {} is not implemented'.format(k))

    # Replace the old module with the new one
    recursive_setattr(model, module_name, new_module)
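
# A minimal sketch of replacing one layer and enabling a technique on it.
# `net` and 'classifier' are hypothetical, and the nested dict mirrors the
# constants imported above; the ratio and method values are illustrative:
#   module_replacement(net, 'classifier', compression_technique={
#       SPARSE_PRUNING: {
#           SPARSE_PRUNING_ENABLED: True,
#           SPARSE_PRUNING_DENSE_RATIO: 0.5,
#           SPARSE_PRUNING_METHOD: 'l1',
#       }
#   })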


def is_module_compressible(module, mpu=None):
    # A module is compressible if it is one of the torch layer types that has
    # a compression wrapper above, or one of the mpu parallel linear layers.
    ret = isinstance(module, torch.nn.Linear) or \
          isinstance(module, torch.nn.Conv2d) or \
          isinstance(module, torch.nn.Embedding) or \
          isinstance(module, torch.nn.BatchNorm2d)

    if mpu is not None:
        ret = ret or isinstance(module, mpu.RowParallelLinear) or \
              isinstance(module, mpu.ColumnParallelLinear)

    return ret
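
# For example (both layers constructed ad hoc):
#   is_module_compressible(torch.nn.Linear(4, 4))    # True
#   is_module_compressible(torch.nn.LayerNorm(4))    # False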


def compression_preparation(model, compression_technique_list, mpu):
    """
    Prepare the compression techniques of a model.
    Args:
        model (`torch.nn.Module`)
            The model to prepare the compression techniques of.
        compression_technique_list (`list`)
            The list of compression techniques to apply to the model.
        mpu
            An optional model-parallel utility object, if model parallelism is used.
    """
    # First replace every compressible module with its compression wrapper
    for module_name, module in model.named_modules():
        if is_module_compressible(module, mpu):
            module_replacement(model, module_name, mpu=mpu)

    # Then enable the requested techniques on the explicitly listed modules
    for module_name_lists, _, compression_technique in compression_technique_list:
        for mnl in module_name_lists:
            for module_name in mnl:
                module_replacement(model, module_name, compression_technique)

    return model
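
# A minimal sketch of preparing a model. Each list item has the shape
# (module_name_lists, _, technique_dict), where the middle element is unused
# here; `net`, the module names, and `technique_dict` are hypothetical:
#   net = compression_preparation(
#       net,
#       [([['encoder.fc1', 'encoder.fc2']], None, technique_dict)],
#       mpu=None)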


def fix_compression(model,
                    module_name,
                    compression_technique,
                    mask=None,
                    dim_reduction=False):
    """
    Fix the compression technique of a module.
    Args:
        model (`torch.nn.Module`)
            The model to fix the compression technique of.
        module_name (`str`)
            The name of the module to fix the compression technique of.
        compression_technique (`dict`)
            The compression technique to fix the module to.
        mask (optional)
            An externally supplied pruning mask to apply while fixing.
        dim_reduction (`bool`)
            Whether to reduce the pruned dimensions of the module when fixing.
    """
    # Here we can make things much simpler by just replacing the module
    module = recursive_getattr(model, module_name)
    for k, v in compression_technique.items():
        if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[WEIGHT_QUANTIZE_ENABLED]:
            return module.fix_weight_quantization()
        elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
            return module.fix_sparse_pruning_helper()
        elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
            return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
        elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
            return module.fix_head_pruning_helper(mask,
                                                  v[HEAD_PRUNING_NUM_HEADS],
                                                  dim_reduction=dim_reduction)
        elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
            return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
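
# A minimal sketch of fixing (making permanent) one technique on a module;
# `net`, the module name, and the ratio/method values are hypothetical:
#   fix_compression(net, 'encoder.fc1',
#                   {SPARSE_PRUNING: {SPARSE_PRUNING_ENABLED: True,
#                                     SPARSE_PRUNING_DENSE_RATIO: 0.5,
#                                     SPARSE_PRUNING_METHOD: 'l1'}})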


def convert_conv1d_to_linear(model, convert_type):
    '''
    This is a helper function to convert Conv1D layers to Linear layers
    (e.g., to convert GPT-2 models from Hugging Face).
    '''
    if hasattr(model, 'module'):
        c_model = model.module
    else:
        c_model = model

    for name, module in c_model.named_modules():
        if isinstance(module, convert_type):
            old_module = recursive_getattr(c_model, name)
            # Conv1D stores its weight transposed relative to nn.Linear,
            # so transpose it back when copying
            new_module = torch.nn.Linear(old_module.weight.data.size(0),
                                         old_module.weight.data.size(1),
                                         bias=old_module.bias is not None)
            new_module.weight.data = old_module.weight.data.t().contiguous()
            if new_module.bias is not None:
                new_module.bias.data = old_module.bias.data.view(-1)

            recursive_setattr(c_model, name, new_module)

    return model
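
# A minimal usage sketch for a Hugging Face GPT-2 model; the Conv1D import
# path varies across transformers versions (an assumption, not pinned here):
#   from transformers.pytorch_utils import Conv1D
#   gpt2 = convert_conv1d_to_linear(gpt2, Conv1D)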