basic_layer.py

'''Copyright The Microsoft DeepSpeed Team'''

import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger

g_mpu = None

class QuantAct(nn.Module):
    """
    Class to quantize given activations. Note that when using this module, the input activation
    quantization range is fixed for all tokens/images at inference time. This generally costs some
    accuracy but achieves better latency.

    Parameters:
    ----------
    act_range_momentum : float, default 0.95
        Momentum for updating the activation quantization range.
    quant_mode : str, default 'symmetric'
    """
    def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
        super(QuantAct, self).__init__()

        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        if quant_mode == 'symmetric':
            self.act_function = SymQuantizer.apply
        else:
            self.act_function = AsymQuantizer.apply

        self.register_buffer('x_min_max', torch.zeros(2))

    def forward(self, x, num_bits, *args):
        """
        x: the activation that we need to quantize
        num_bits: the number of bits we need to quantize the activation to
        *args: extra arguments that are unused but needed to align with the interface of the other quantization functions
        """
        if self.training:
            x_min = x.data.min()
            x_max = x.data.max()

            # Initialization
            if self.x_min_max[0] == self.x_min_max[1]:
                self.x_min_max[0] = x_min
                self.x_min_max[1] = x_max

            # if momentum is not needed, set self.act_range_momentum = 0
            self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
            self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])

        return x_q

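# Illustrative usage sketch (not part of the original file): how QuantAct's momentum-based
# range tracking is typically exercised. The bit width and tensor shape below are hypothetical.
#
#   act_quant = QuantAct(act_range_momentum=0.95, quant_mode='symmetric')
#   act_quant.train()
#   x_q = act_quant(torch.randn(8, 16), 8)   # first call initializes x_min_max, then applies the EMA update
#   act_quant.eval()
#   x_q = act_quant(torch.randn(8, 16), 8)   # at inference the stored [min, max] range is reused unchanged
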
class Embedding_Compress(nn.Embedding):
    def __init__(self, *kargs):
        super(Embedding_Compress, self).__init__(*kargs)
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False

    def extra_repr(self):
        return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
            self.num_embeddings, self.embedding_dim, self.weight.target_bits)

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            # for embedding, we always use token-wise quantization
            self.weight_quantize_num_groups = self.weight.size(0)

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight

        out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        return out

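# Illustrative usage sketch (not part of the original file): enabling token-wise weight
# quantization on Embedding_Compress. All argument values are hypothetical; note that the
# num_groups argument is effectively ignored here because the layer always quantizes per token row.
#
#   emb = Embedding_Compress(30522, 768)
#   emb.enable_weight_quantization(start_bits=12, target_bits=8, quantization_period=100,
#                                  weight_quantization_enabled_in_forward=True,
#                                  quantization_type='symmetric', num_groups=1)
#   emb.weight_quantization_enabled = True          # normally toggled by the compression scheduler
#   hidden = emb(torch.tensor([[1, 5, 20]]))        # forward() quantizes the weight before the lookup
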
class LinearLayer_Compress(nn.Linear):
    """
    Linear layer with compression.
    """
    def __init__(self, *kargs, bias=True):
        super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
        self.sparse_pruning_method = None
        self.row_pruning_method = None
        self.head_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.row_pruning_enabled = False
        self.head_pruning_enabled = False
        self.activation_quantization_enabled = False

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None,
            self.row_pruning_method is not None, self.head_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        # Here, we support two cases: L1-norm-based pruning and topk-based pruning
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_row_pruning(self, ratio, method):
        # Here, we support two cases: L1-norm-based pruning and topk-based pruning
        self.row_pruning_ratio = ratio
        self.row_pruning_method = method

        if method == 'l1':
            # compute the l1 norm of each row (output neuron)
            weight_norm = torch.norm(self.weight.data, p=1, dim=1)
            mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
            mask = mask.view(-1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
            self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('row_pruning_mask', mask)

    def enable_head_pruning(self, ratio, method, num_heads):
        # Here, we support only topk-based pruning
        self.num_heads = num_heads
        self.head_pruning_ratio = ratio
        self.head_pruning_method = method

        if method not in ['topk']:
            raise NotImplementedError
        else:
            self.head_pruning_ratio = ratio
            self.head_pruning_scores = nn.Parameter(torch.Tensor(1, self.num_heads))  # we apply the pruning to the O matrix
            self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
        # This function is used for row/col pruning.
        # In particular, if we have two back-to-back layers, F1 and F2, then when we
        # remove rows from F1 we also need to remove the corresponding columns from F2.
        # However, if we only have one layer, F1, we only need to mask the pruned rows
        # as 0 in F1.
        if mask is None:
            mask = self.get_mask(pruning_type='row').bool()
            if dim_reduction:
                start_bits = self.weight.start_bits
                target_bits = self.weight.target_bits
                q_period = self.weight.q_period
                self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
                self.weight.start_bits = start_bits
                self.weight.target_bits = target_bits
                self.weight.q_period = q_period
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_features = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)

            del self.row_pruning_mask
            if self.row_pruning_method == 'topk':
                del self.row_mask_scores
            self.row_pruning_method = None
        else:
            # this is generally for column pruning
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            self.in_features = self.weight.size(1)
            mask = None

        self.row_pruning_enabled = False
        return mask

    def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
        # Similar to row/col pruning, head pruning also needs to prune the QKV matrices
        # associated with the O matrix.
        num_heads = num_heads if num_heads else self.num_heads
        if mask is None:
            if self.head_pruning_method == 'topk':
                mask = self.get_mask(pruning_type='head').bool()
                if dim_reduction:
                    shape = self.weight.size(0)
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(
                        self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                else:
                    shape = self.weight.size()
                    self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) *
                                        mask.view(-1, 1)).reshape(shape[1], shape[0]).t()

                if self.head_pruning_method == 'topk':
                    del self.head_pruning_scores
                self.head_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            shape = self.weight.size(1)
            self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))

        self.head_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='row'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        if pruning_type == 'row':
            if self.row_pruning_method == 'l1':
                return self.row_pruning_mask.to(self.weight.device)
            elif self.row_pruning_method == 'topk':
                return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'head':
            if self.head_pruning_method == 'topk':
                return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def head_pruning_reshape(self, w, mask):
        shape = w.shape
        return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()

    def forward(self, input, skip_bias_add=False):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.row_pruning_enabled and self.row_pruning_method:
            mask = self.get_mask(pruning_type='row')
            weight = weight * mask.view(-1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.head_pruning_enabled and self.head_pruning_method:
            mask = self.get_mask(pruning_type='head')
            weight = self.head_pruning_reshape(weight, mask)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input.size(-1)
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        if skip_bias_add:
            # used for mpu linear layers
            output = nn.functional.linear(input, weight, None)
            return output, bias
        else:
            output = nn.functional.linear(input, weight, bias)
            return output

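# Illustrative usage sketch (not part of the original file): combining sparse pruning and
# dynamic activation quantization on LinearLayer_Compress. Ratios and bit widths are hypothetical.
#
#   fc = LinearLayer_Compress(1024, 1024, bias=True)
#   fc.enable_sparse_pruning(ratio=0.5, method='l1')          # builds the 'sparse_pruning_mask' buffer
#   fc.sparse_pruning_enabled = True                          # normally toggled by the compression scheduler
#   fc.enable_activation_quantization(bits=8, quantization_type='symmetric', range_calibration='dynamic')
#   fc.activation_quantization_enabled = True
#   y = fc(torch.randn(4, 1024))                              # masked weight, token-wise quantized input
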
class Conv2dLayer_Compress(nn.Conv2d):
    """
    Conv2D layer with compression.
    """
    def __init__(self, *kargs):
        super(Conv2dLayer_Compress, self).__init__(*kargs)
        self.sparse_pruning_method = None
        self.channel_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False  # default added for parity with LinearLayer_Compress; forward() reads this flag
        self.sparse_pruning_enabled = False
        self.channel_pruning_enabled = False
        self.activation_quantization_enabled = False

    def __repr__(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        output = s.format(**self.__dict__)
        return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def enable_sparse_pruning(self, ratio, method):
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_channel_pruning(self, ratio, method):
        # Here, we support two cases: L1-norm-based pruning and topk-based pruning
        self.channel_pruning_ratio = ratio
        self.channel_pruning_method = method

        if method == 'l1':
            # compute the l1 norm of each conv2d kernel (the last three dimensions)
            weight_norm = torch.norm(self.weight.data, p=1, dim=[1, 2, 3])
            mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
            mask = mask.view(-1, 1, 1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
            self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('channel_pruning_mask', mask)

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
        if mask is None:
            if self.channel_pruning_method in ['l1', 'topk']:
                mask = self.get_mask(pruning_type='channel').bool()
                if dim_reduction:
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                    if self.bias is not None:
                        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                else:
                    self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
                    if self.bias is not None:
                        self.bias.data = self.bias.data * mask.view(-1)

                del self.channel_pruning_mask
                if self.channel_pruning_method == 'topk':
                    del self.channel_mask_scores
                self.channel_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            mask = None

        self.channel_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='sparse'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'channel':
            if self.channel_pruning_method == 'l1':
                return self.channel_pruning_mask.to(self.weight.device)
            elif self.channel_pruning_method == 'topk':
                return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            assert self.weight.target_bits >= 4, 'Only >=4-bit weight quantization is supported during the forward pass for now'
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider using the DS-FP16 optimizer ************"
            )
            if quantization_type == 'symmetric':
                self.weight_quantizer = SymQuantizer.apply
            else:
                self.weight_quantizer = AsymQuantizer.apply
            self.weight_quantize_num_groups = num_groups

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.channel_pruning_enabled:
            mask = self.get_mask(pruning_type='channel')
            weight = weight * mask.view(-1, 1, 1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input[0].numel()
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class BNLayer_Compress(nn.BatchNorm2d):
    def fix_channel_pruning_helper(self, mask, dim_reduction=True):
        self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
        self.running_mean = self.running_mean[mask.view(-1)]
        self.running_var = self.running_var[mask.view(-1)]

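# Illustrative sketch (not part of the original file): when a Conv2dLayer_Compress has output
# channels removed with dim_reduction=True, the mask it returns can be reused to shrink the
# following BatchNorm and the next convolution's input channels. Shapes below are hypothetical.
#
#   conv1 = Conv2dLayer_Compress(64, 128, 3)
#   bn1 = BNLayer_Compress(128)
#   conv2 = Conv2dLayer_Compress(128, 256, 3)
#   conv1.enable_channel_pruning(ratio=0.25, method='l1')
#   mask = conv1.fix_channel_pruning_helper(dim_reduction=True)   # returns the kept-channel mask
#   bn1.fix_channel_pruning_helper(mask)                          # shrink BN affine params and running stats
#   conv2.fix_channel_pruning_helper(mask, dim_reduction=True)    # drop the matching input channels
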
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_

def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    assert tensor.size()[last_dim] % num_partitions == 0
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

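# Illustrative example (not part of the original file): splitting the hidden dimension into
# equal chunks, which is what _split() below does with one chunk per model-parallel rank.
#
#   t = torch.arange(12).reshape(2, 6)
#   chunks = split_tensor_along_last_dim(t, 3)                                # three views of shape (2, 2)
#   chunks = split_tensor_along_last_dim(t, 3, contiguous_split_chunks=True)  # contiguous copies instead
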
def _split(input_):
    """Split the tensor along its last dimension and keep the corresponding slice."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Split along last dimension.
    world_size = dist.get_world_size(group=group)
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous()

    return output

def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""
    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""
    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

# -----------------
# Helper functions.
# -----------------


def copy_to_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)

class ColumnParallelLinear_Compress(LinearLayer_Compress):
    def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
        # Keep input parameters
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the last dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert output_size % world_size == 0
        self.output_size_per_partition = output_size // world_size

        super(ColumnParallelLinear_Compress, self).__init__(self.input_size,
                                                            self.output_size_per_partition,
                                                            bias=bias)

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        if self.skip_bias_add:
            output_parallel, bias = super().forward(input_parallel, True)
        else:
            output_parallel = super().forward(input_parallel)
            bias = None
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel

        return output, bias

class RowParallelLinear_Compress(LinearLayer_Compress):
    def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
        # Keep input parameters
        global g_mpu
        g_mpu = mpu
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add

        # Divide the weight matrix along the last dimension.
        world_size = mpu.get_model_parallel_world_size()
        assert input_size % world_size == 0
        self.input_size_per_partition = input_size // world_size

        super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition,
                                                         self.output_size,
                                                         bias=bias)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel, bias = super().forward(input_parallel, True)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            if bias is not None:
                output = output_ + bias
            else:
                output = output_
            output_bias = None
        else:
            output = output_
            output_bias = bias

        return output, output_bias

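
# Illustrative sketch (not part of the original file): how the two parallel linear layers are
# typically paired in a model-parallel MLP block. `mpu`, `hidden`, and `hidden_states` are
# placeholders; `mpu` stands for a model-parallel unit object exposing get_model_parallel_group(),
# get_model_parallel_world_size(), and get_model_parallel_rank(), as assumed by g_mpu above.
#
#   dense_h_to_4h = ColumnParallelLinear_Compress(mpu, hidden, 4 * hidden, gather_output=False)
#   dense_4h_to_h = RowParallelLinear_Compress(mpu, 4 * hidden, hidden, input_is_parallel=True)
#   intermediate, _ = dense_h_to_4h(hidden_states)   # each rank holds a slice of the 4h dimension
#   output, bias = dense_4h_to_h(intermediate)       # partial results are all-reduced across ranks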