# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger

g_mpu = None

class QuantAct(nn.Module):
    """
    Class to quantize given activations. Note that when using this module, the input activation
    quantization range will be fixed for all tokens/images at inference time. This generally costs
    some accuracy but achieves better latency.

    Parameters:
    ----------
    act_range_momentum : float, default 0.95
        Momentum for updating the activation quantization range.
    quant_mode : str, default 'symmetric'
        Quantization mode, either 'symmetric' or 'asymmetric'.
    """

    def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
        super(QuantAct, self).__init__()

        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        if quant_mode == 'symmetric':
            self.act_function = SymQuantizer.apply
        else:
            self.act_function = AsymQuantizer.apply

        self.register_buffer('x_min_max', torch.zeros(2))

    def forward(self, x, num_bits, *args):
        """
        x: the activation that we need to quantize
        num_bits: the number of bits we need to quantize the activation to
        *args: extra arguments that are unused here but keep the interface aligned with the
               other quantization functions
        """
        if self.training:
            x_min = x.data.min()
            x_max = x.data.max()

            # Initialization
            if self.x_min_max[0] == self.x_min_max[1]:
                self.x_min_max[0] = x_min
                self.x_min_max[1] = x_max

            # if momentum is not desired, set self.act_range_momentum = 0
            self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
            self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])

        return x_q

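
# A minimal usage sketch for QuantAct (illustrative only; tensor values are made up):
# the running range is an exponential moving average of per-batch min/max, so the
# quantization grid stabilizes over training and is frozen once the module is in eval mode.
#
#   quant = QuantAct(act_range_momentum=0.95, quant_mode='symmetric')
#   quant.train()
#   _ = quant(torch.randn(4, 16), num_bits=8)   # first call initializes x_min_max
#   _ = quant(torch.randn(4, 16), num_bits=8)   # later calls update it with momentum 0.95
#   quant.eval()
#   y = quant(torch.randn(4, 16), num_bits=8)   # quantizes with the frozen calibrated range
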
class Embedding_Compress(nn.Embedding):

    def __init__(self, *kargs):
        super(Embedding_Compress, self).__init__(*kargs)
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False

    def extra_repr(self):
        return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
            self.num_embeddings, self.embedding_dim, self.weight.target_bits)

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            # for embeddings, we always use token-wise (per-row) quantization
            self.weight_quantize_num_groups = self.weight.size(0)

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight

        out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        return out

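
# A minimal sketch of driving Embedding_Compress by hand (the argument values are
# assumptions; in DeepSpeed this wiring is normally done by the compression config and
# scheduler, and num_groups is overridden internally with token-wise grouping):
#
#   emb = Embedding_Compress(30522, 768)
#   emb.enable_weight_quantization(start_bits=12, target_bits=8, quantization_period=100,
#                                  weight_quantization_enabled_in_forward=True,
#                                  quantization_type='symmetric', num_groups=1)
#   emb.weight_quantization_enabled = True   # both flags must be on for the quantized path
#   out = emb(torch.tensor([[1, 2, 3]]))     # embedding lookup on quantized weights
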
class LinearLayer_Compress(nn.Linear):
    """
    Linear layer with compression.
    """

    def __init__(self, *kargs, bias=True):
        super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
        self.sparse_pruning_method = None
        self.row_pruning_method = None
        self.head_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.row_pruning_enabled = False
        self.head_pruning_enabled = False
        self.activation_quantization_enabled = False

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None,
            self.row_pruning_method is not None, self.head_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

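    # Design note on the two modes above (a sketch; sizes are assumptions): 'l1' freezes a
    # mask once from the current weight magnitudes, while 'topk' registers learnable scores
    # and re-derives the mask from them via TopKBinarizer on every get_mask() call, so the
    # pruning pattern can keep adapting during training.
    #
    #   layer = LinearLayer_Compress(8, 8)
    #   layer.enable_sparse_pruning(ratio=0.5, method='l1')   # static mask stored as a buffer
    #   # with method='topk', layer.sparse_mask_scores is trainable and no mask is stored
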
    def enable_row_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.row_pruning_ratio = ratio
        self.row_pruning_method = method

        if method == 'l1':
            # compute the L1 norm of each row
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=1)
            mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
            mask = mask.view(-1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
            self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('row_pruning_mask', mask)

    def enable_head_pruning(self, ratio, method, num_heads):
        # Here, we support only topk based pruning
        self.num_heads = num_heads
        self.head_pruning_ratio = ratio
        self.head_pruning_method = method

        if method not in ['topk']:
            raise NotImplementedError
        else:
            self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
                                                                 self.num_heads))  # we apply the pruning to the O matrix
            self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
        # This function is used for row/col pruning.
        # In particular, if we have two back-to-back layers, F1 and F2, then when we
        # remove rows from F1 we also need to remove the corresponding columns from F2.
        # However, if we only have one layer, F1, we only need to mask the pruned rows
        # of F1 with 0.
        if mask is None:
            mask = self.get_mask(pruning_type='row').bool()
            if dim_reduction:
                start_bits = self.weight.start_bits
                target_bits = self.weight.target_bits
                q_period = self.weight.q_period
                self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
                self.weight.start_bits = start_bits
                self.weight.target_bits = target_bits
                self.weight.q_period = q_period
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_features = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)

            del self.row_pruning_mask
            if self.row_pruning_method == 'topk':
                del self.row_mask_scores
            self.row_pruning_method = None
        else:
            # this is generally for column pruning
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            self.in_features = self.weight.size(1)
            mask = None

        self.row_pruning_enabled = False
        return mask

    def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
        # similar to row/col pruning: pruning heads of the O matrix also requires
        # pruning the associated QKV matrices
        num_heads = num_heads if num_heads else self.num_heads
        if mask is None:
            if self.head_pruning_method == 'topk':
                mask = self.get_mask(pruning_type='head').bool()
                if dim_reduction:
                    shape = self.weight.size(0)
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads,
                                                                            -1)[mask.view(-1), :].reshape(-1,
                                                                                                          shape).t())
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                else:
                    shape = self.weight.size()
                    self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
                        shape[1], shape[0]).t()

                if self.head_pruning_method == 'topk':
                    del self.head_pruning_scores
                self.head_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            shape = self.weight.size(1)
            self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))

        self.head_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='row'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        if pruning_type == 'row':
            if self.row_pruning_method == 'l1':
                return self.row_pruning_mask.to(self.weight.device)
            elif self.row_pruning_method == 'topk':
                return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'head':
            if self.head_pruning_method == 'topk':
                return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def head_pruning_reshape(self, w, mask):
        shape = w.shape
        return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()

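    # Shape walk-through for head_pruning_reshape (a sketch with assumed sizes): for an
    # attention output projection with weight w of shape (out, in) = (768, 768) and
    # num_heads = 12, w.t() is (in, out); reshaping to (12, 49152) groups the 64 input
    # dimensions belonging to each head into one row, mask.view(-1, 1) zeroes whole heads
    # by broadcasting, and the final reshape(...).t() restores the (out, in) layout.
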
    def forward(self, input, skip_bias_add=False):

        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.row_pruning_enabled and self.row_pruning_method:
            mask = self.get_mask(pruning_type='row')
            weight = weight * mask.view(-1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.head_pruning_enabled and self.head_pruning_method:
            mask = self.get_mask(pruning_type='head')
            weight = self.head_pruning_reshape(weight, mask)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input.size(-1)
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        if skip_bias_add:
            # used for mpu linear layers
            output = nn.functional.linear(input, weight, None)
            return output, bias
        else:
            output = nn.functional.linear(input, weight, bias)
            return output

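
# A minimal end-to-end sketch for LinearLayer_Compress (the argument values are
# assumptions; in DeepSpeed these switches are normally flipped by the compression
# scheduler rather than by hand):
#
#   layer = LinearLayer_Compress(16, 16, bias=True)
#   layer.enable_sparse_pruning(ratio=0.5, method='topk')
#   layer.sparse_pruning_enabled = True      # masks are applied inside forward()
#   y = layer(torch.randn(2, 16))
#   layer.fix_sparse_pruning_helper()        # bake the mask into weight.data permanently
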
class Conv2dLayer_Compress(nn.Conv2d):
    """
    Conv2D layer with compression.
    """

    def __init__(self, *kargs):
        super(Conv2dLayer_Compress, self).__init__(*kargs)
        self.sparse_pruning_method = None
        self.channel_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False  # checked in forward(); mirrors LinearLayer_Compress
        self.sparse_pruning_enabled = False
        self.channel_pruning_enabled = False
        self.activation_quantization_enabled = False

    def __repr__(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        output = s.format(**self.__dict__)
        return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_channel_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.channel_pruning_ratio = ratio
        self.channel_pruning_method = method

        if method == 'l1':
            # compute the L1 norm of each conv2d kernel (over the last three dimensions);
            # vector_norm is used because torch.linalg.norm does not accept a 3-dim reduction
            weight_norm = torch.linalg.vector_norm(self.weight.data, ord=1, dim=(1, 2, 3))
            mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
            mask = mask.view(-1, 1, 1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
            self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('channel_pruning_mask', mask)

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
        if mask is None:
            if self.channel_pruning_method in ['l1', 'topk']:
                mask = self.get_mask(pruning_type='channel').bool()
                if dim_reduction:
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                    if self.bias is not None:
                        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                else:
                    self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
                    if self.bias is not None:
                        self.bias.data = self.bias.data * mask.view(-1)

                del self.channel_pruning_mask
                if self.channel_pruning_method == 'topk':
                    del self.channel_mask_scores
                self.channel_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            mask = None

        self.channel_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='sparse'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'channel':
            if self.channel_pruning_method == 'l1':
                return self.channel_pruning_mask.to(self.weight.device)
            elif self.channel_pruning_method == 'topk':
                return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            assert self.weight.target_bits >= 4, 'Only >=4-bit weight quantization is supported during the forward pass for now'
            logger.warning(
                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
            )
            if quantization_type == 'symmetric':
                self.weight_quantizer = SymQuantizer.apply
            else:
                self.weight_quantizer = AsymQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.channel_pruning_enabled:
            mask = self.get_mask(pruning_type='channel')
            weight = weight * mask.view(-1, 1, 1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input[0].numel()
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class BNLayer_Compress(nn.BatchNorm2d):

    def fix_channel_pruning_helper(self, mask, dim_reduction=True):
        self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
        self.running_mean = self.running_mean[mask.view(-1)]
        self.running_var = self.running_var[mask.view(-1)]

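
# A propagation sketch (shapes are assumptions): when a conv's output channels are pruned
# with dim_reduction=True, the same boolean mask must be applied to the following BatchNorm
# and to the next conv's input channels so the shapes stay consistent:
#
#   conv1 = Conv2dLayer_Compress(3, 16, 3)
#   bn1 = BNLayer_Compress(16)
#   conv2 = Conv2dLayer_Compress(16, 32, 3)
#   conv1.enable_channel_pruning(ratio=0.5, method='l1')
#   m = conv1.fix_channel_pruning_helper(dim_reduction=True)   # returns the kept-channel mask
#   bn1.fix_channel_pruning_helper(m)                          # shrink BN stats/affine params
#   conv2.fix_channel_pruning_helper(m, dim_reduction=True)    # shrink conv2 input channels
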
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_

def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: if True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    assert tensor.size()[last_dim] % num_partitions == 0
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

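
# For example (a quick sanity sketch): splitting a (4, 6) tensor into 3 partitions yields
# three views of shape (4, 2); pass contiguous_split_chunks=True if a downstream op
# (e.g. .view()) requires contiguous memory.
#
#   chunks = split_tensor_along_last_dim(torch.arange(24.).reshape(4, 6), 3)
#   assert all(c.shape == (4, 2) for c in chunks)
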
def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Split along last dimension.
    world_size = dist.get_world_size(group=group)
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()

    return output

def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to this rank."""

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

# -----------------
# Helper functions.
# -----------------


def copy_to_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)

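
# Design note: each region function pairs with its adjoint, so copy/reduce and
# scatter/gather swap roles between forward and backward. For example,
# copy_to_model_parallel_region is the identity in forward but all-reduces gradients
# in backward, which is exactly what an input shared by all column-parallel
# partitions needs.
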
class ColumnParallelLinear_Compress(LinearLayer_Compress):

    def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
        # Keep input parameters
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the last dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert output_size % world_size == 0
        self.output_size_per_partition = output_size // world_size

        super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        if self.skip_bias_add:
            output_parallel, bias = super().forward(input_parallel, True)
        else:
            output_parallel = super().forward(input_parallel)
            bias = None
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel

        return output, bias

class RowParallelLinear_Compress(LinearLayer_Compress):

    def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
        # Keep input parameters
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the last dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert input_size % world_size == 0
        self.input_size_per_partition = input_size // world_size

        super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel, bias = super().forward(input_parallel, True)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            if bias is not None:
                output = output_ + bias
            else:
                output = output_
            output_bias = None
        else:
            output = output_
            output_bias = bias

        return output, output_bias

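
# A minimal composition sketch (the mpu object, world size, and dimensions are
# assumptions): the classic Megatron-style MLP pairs a column-parallel layer with
# gather_output=False and a row-parallel layer with input_is_parallel=True, so the
# intermediate activations stay partitioned and only one all-reduce is needed per block.
#
#   fc1 = ColumnParallelLinear_Compress(mpu, 1024, 4096, gather_output=False)
#   fc2 = RowParallelLinear_Compress(mpu, 4096, 1024, input_is_parallel=True)
#   x = torch.randn(2, 1024)
#   h, _ = fc1(x)
#   h = torch.nn.functional.gelu(h)
#   y, _ = fc2(h)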