utils.py

'''Copyright The Microsoft DeepSpeed Team'''
import torch
from torch import autograd
import math


class TopKBinarizer(autograd.Function):
    """
    Top-k Binarizer.
    Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
    is among the k% highest values of S.
    Implementation is inspired from:
    https://github.com/yaozhewei/MLPruning
    """
    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold: float, sigmoid: bool):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask.
            threshold (`float`)
                The percentage of weights to keep (the rest is pruned).
                `threshold` is a float between 0 and 1.
            sigmoid (`bool`)
                Whether to apply a sigmoid on the threshold.
        Returns:
            mask (`torch.FloatTensor`)
                Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
                retained, 0 - the associated weight is pruned).
        """
        # Get the subnetwork by sorting the inputs and using the top threshold
        if sigmoid:
            threshold = torch.sigmoid(threshold).item()
        ctx.sigmoid = sigmoid
        mask = inputs.clone()

        _, idx = inputs.flatten().sort(descending=True)
        j = math.ceil(threshold * inputs.numel())

        # flat_out and mask access the same memory.
        flat_out = mask.flatten()
        flat_out[idx[j:]] = 0.
        flat_out[idx[:j]] = 1.
        ctx.save_for_backward(mask)

        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        mask, = ctx.saved_tensors
        if ctx.sigmoid:
            return gradOutput.clone(), ((gradOutput * mask).sum()).view(-1), None
        else:
            return gradOutput.clone(), None, None
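

# A minimal usage sketch (hypothetical; the names `_demo_topk_binarizer`, `scores`,
# and `keep_ratio` are illustrative, not part of the library). It shows TopKBinarizer
# keeping the top 30% of entries and passing gradients straight through the binarization.
def _demo_topk_binarizer():
    scores = torch.randn(16, 16, requires_grad=True)
    keep_ratio = 0.3  # keep the top 30% of the score entries
    mask = TopKBinarizer.apply(scores, keep_ratio, False)
    # exactly ceil(0.3 * numel) entries are set to 1, the rest to 0
    assert mask.sum().item() == math.ceil(keep_ratio * scores.numel())
    # backward() returns gradOutput.clone() for `inputs`, i.e. a straight-through estimator
    (mask * scores).sum().backward()
    return mask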


class SymQuantizer(torch.autograd.Function):
    """
    Symmetric quantization
    """
    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None
                and max_value is None) or (min_value is not None
                                           and max_value is not None and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            input = input.reshape(num_groups, -1)
            max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1)
        else:
            max_input = torch.max(min_value.abs(), max_value).view(-1)
        scale = 2 * max_input / q_range
        output = (input / scale).round().clamp(-q_range // 2, q_range // 2 - 1) * scale
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
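

# A minimal usage sketch (hypothetical; `_demo_sym_quantizer` and `weights` are
# illustrative names). With min_value/max_value left as None, the scale is computed
# dynamically per group from the maximum absolute value within each group.
def _demo_sym_quantizer():
    weights = torch.randn(4, 256, requires_grad=True)
    q_weights = SymQuantizer.apply(weights, 8, None, None, 4)  # 8-bit, 4 groups
    # output keeps the input shape; gradients pass straight through
    q_weights.sum().backward()
    return q_weights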


class AsymQuantizer(torch.autograd.Function):
    """
    Asymmetric quantization
    """
    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int, >=4)
                Number of bits to use for quantization
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None
                and max_value is None) or (min_value is not None
                                           and max_value is not None and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
            input = input.reshape(num_groups, -1)
            min_value = input.amin(dim=-1, keepdim=True)
            max_value = input.amax(dim=-1, keepdim=True)
        scale = (max_value - min_value) / q_range
        zero_point = (min_value / scale).round() * scale
        output = ((input - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
        output = output.reshape(input_shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
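

# A minimal usage sketch (hypothetical; `_demo_asym_quantizer` and `acts` are
# illustrative names). Asymmetric quantization fits skewed ranges such as post-ReLU
# activations, since the per-group min/max define both the scale and the zero point.
def _demo_asym_quantizer():
    acts = torch.relu(torch.randn(2, 128, requires_grad=True))
    q_acts = AsymQuantizer.apply(acts, 8, None, None, 2)  # 8-bit, 2 groups
    q_acts.sum().backward()
    return q_acts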


class TernaryQuantizer(torch.autograd.Function):
    """
    Ternary quantization
    """
    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int)
                Dummy variable
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization; for now they are dummy variables
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None)
        input_flat = input.reshape(num_groups, -1)
        n = input_flat.shape[1]
        m = input_flat.norm(p=1, dim=1).div(n)
        thres = (0.7 * m).view(-1, 1)
        pos = (input_flat > thres).type(input.type())
        neg = (input_flat < -thres).type(input.type())
        mask = (input_flat.abs() > thres).type(input.type())
        alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
        output = alpha * pos - alpha * neg
        output = output.reshape(input.shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
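

# A minimal usage sketch (hypothetical; `_demo_ternary_quantizer` and `weights` are
# illustrative names). Each group is mapped to {-alpha, 0, +alpha}, with the cutoff
# set to 0.7 times the group's mean absolute value.
def _demo_ternary_quantizer():
    weights = torch.randn(8, 64, requires_grad=True)
    t_weights = TernaryQuantizer.apply(weights, 2, None, None, 8)  # num_bits is a dummy argument
    t_weights.sum().backward()
    return t_weights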


class BinaryQuantizer(torch.autograd.Function):
    """
    Binary quantization
    """
    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
        Args:
            input (`torch.FloatTensor`)
                The input which needs to be quantized
            num_bits (int)
                Dummy variable
            min_value/max_value (torch.FloatTensor)
                Used for static activation quantization; for now they are dummy variables
            num_groups (int)
                How many groups to partition the quantization into
        Returns:
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None)
        input_flat = input.reshape(num_groups, -1)
        n = input_flat.shape[1]
        m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
        output = input_flat.sign().mul(m)
        output = output.reshape(input.shape).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None
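

# A minimal usage sketch (hypothetical; `_demo_binary_quantizer` and `weights` are
# illustrative names). Each group is replaced by sign(w) scaled by the group's mean
# absolute value, so a group ends up with essentially two values, +m and -m.
def _demo_binary_quantizer():
    weights = torch.randn(8, 64, requires_grad=True)
    b_weights = BinaryQuantizer.apply(weights, 1, None, None, 8)  # num_bits is a dummy argument
    b_weights.sum().backward()
    return b_weights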