utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from torch import autograd
import math


class TopKBinarizer(autograd.Function):
    """
    Top-k Binarizer.
    Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
    is among the k% highest values of S.
    Implementation is inspired from:
    https://github.com/yaozhewei/MLPruning
    """

    @staticmethod
    def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask.
            threshold (`float`)
                The percentage of weights to keep (the rest is pruned).
                `threshold` is a float between 0 and 1.
            sigmoid (`bool`)
                Whether to apply a sigmoid on the threshold.
        Returns:
            mask (`torch.FloatTensor`)
                Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
                retained, 0 - the associated weight is pruned).
        """
        # Get the subnetwork by sorting the inputs and using the top threshold
        if sigmoid:
            threshold = torch.sigmoid(threshold).item()
        ctx.sigmoid = sigmoid
        mask = inputs.clone()

        _, idx = inputs.flatten().sort(descending=True)
        j = math.ceil(threshold * inputs.numel())

        # flat_out and mask access the same memory.
        flat_out = mask.flatten()
        flat_out[idx[j:]] = 0.
        flat_out[idx[:j]] = 1.
        ctx.save_for_backward(mask)
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        mask, = ctx.saved_tensors
        if ctx.sigmoid:
            # Straight-through for the scores; the threshold receives (grad * mask).sum().
            return gradOutput.clone(), ((gradOutput * mask).sum()).view(-1), None
        else:
            return gradOutput.clone(), None, None


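# Example (hypothetical helper, illustrative values only): prune a random score matrix
# with TopKBinarizer, keeping the top 30% of entries. The shape and the 0.3 threshold
# are assumed for illustration; sigmoid=False keeps the threshold a plain float.
def _example_topk_binarizer():
    scores = torch.randn(8, 8, requires_grad=True)
    mask = TopKBinarizer.apply(scores, 0.3, False)
    # The mask multiplies the weights; gradients flow straight through to `scores`.
    (mask * scores).sum().backward()
    return mask, scores.grad

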
class SymQuantizer(torch.autograd.Function):
    """
    Symmetric quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                             and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            # Dynamic quantization: use the per-group absolute maximum as the range.
            input = input.reshape(num_groups, -1)
            max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1)
        else:
            # Static quantization: use the provided activation range.
            max_input = torch.max(min_value.abs(), max_value).view(-1)

        scale = 2 * max_input / q_range
        # Fake quantization: quantize, clamp to the signed integer range, then dequantize.
        output = (input / scale).round().clamp(-q_range // 2, q_range // 2 - 1) * scale
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None


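# Example (hypothetical helper, illustrative values only): 8-bit symmetric fake
# quantization of a weight tensor with 4 quantization groups; min/max are left as
# None so the range is computed dynamically per group.
def _example_sym_quantizer():
    w = torch.randn(4, 16, requires_grad=True)
    w_q = SymQuantizer.apply(w, 8, None, None, 4)
    quant_error = (w_q - w).abs().max()
    return w_q, quant_error

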
class AsymQuantizer(torch.autograd.Function):
    """
    Asymmetric quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                             and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            # Dynamic quantization: compute the per-group min/max range.
            input = input.reshape(num_groups, -1)
            min_value = input.amin(dim=-1, keepdim=True)
            max_value = input.amax(dim=-1, keepdim=True)

        scale = (max_value - min_value) / q_range
        zero_point = (min_value / scale).round() * scale
        # Fake quantization: shift by the zero point, quantize, clamp to the unsigned
        # integer range, then dequantize.
        output = ((input - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None


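# Example (hypothetical helper, illustrative values only): 8-bit asymmetric fake
# quantization of a non-negative activation tensor, again with dynamic per-group ranges.
def _example_asym_quantizer():
    acts = torch.relu(torch.randn(4, 16))  # skewed toward zero, so an asymmetric range fits better
    acts_q = AsymQuantizer.apply(acts, 8, None, None, 4)
    return acts_q

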
class TernaryQuantizer(torch.autograd.Function):
    """
    Ternary quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int)
                Dummy variable
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization; for now they are dummy variables
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None)
        input_flat = input.reshape(num_groups, -1)
        n = input_flat.shape[1]
        # Per-group threshold: 0.7 times the mean absolute value of the group.
        m = input_flat.norm(p=1, dim=1).div(n)
        thres = (0.7 * m).view(-1, 1)
        pos = (input_flat > thres).type(input.type())
        neg = (input_flat < -thres).type(input.type())
        mask = (input_flat.abs() > thres).type(input.type())
        # Per-group scaling factor: mean magnitude of the values that survive the threshold.
        alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
        output = alpha * pos - alpha * neg
        output = output.reshape(input.shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None


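# Example (hypothetical helper, illustrative values only): ternarize a weight tensor
# into {-alpha, 0, +alpha} per group; num_bits is ignored by this quantizer, so 2 is
# just a placeholder.
def _example_ternary_quantizer():
    w = torch.randn(4, 64)
    w_t = TernaryQuantizer.apply(w, 2, None, None, 4)
    # Each group now contains at most three distinct values.
    return w_t

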
class BinaryQuantizer(torch.autograd.Function):
    """
    Binary quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int)
                Dummy variable
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization; for now they are dummy variables
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None)
        input_flat = input.reshape(num_groups, -1)
        n = input_flat.shape[1]
        # Per-group scaling factor: mean absolute value of the group.
        m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
        # Binarize to {-m, +m} via the sign function.
        output = input_flat.sign().mul(m)
        output = output.reshape(input.shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None


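# Example (hypothetical helper, illustrative values only): binarize a weight tensor
# into {-m, +m} per group, where m is the group's mean absolute value.
def _example_binary_quantizer():
    w = torch.randn(4, 64)
    w_b = BinaryQuantizer.apply(w, 1, None, None, 4)
    return w_b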