12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- import math
- from torch import nn
- from torch.nn import init
- import deepspeed.comm as dist
- from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
- from deepspeed.utils import logger
# Module-level handle to the model-parallel unit (mpu). It starts out unset and
# is assigned by the constructors of ColumnParallelLinear_Compress and
# RowParallelLinear_Compress; _reduce, _split and _gather read it.
g_mpu = None
class QuantAct(nn.Module):
    """
    Quantize activations using a range that is tracked during training.

    When this module is used, the activation quantization range is fixed for all
    tokens/images at inference time. This generally costs some accuracy but
    yields better latency.

    Parameters:
    ----------
    act_range_momentum : float, default 0.95
        Momentum of the moving average that tracks the activation range.
        Set it to 0 to use only the range of the current batch.
    quant_mode : str, default 'symmetric'
        'symmetric' selects SymQuantizer; any other value selects AsymQuantizer.
    """

    def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
        super(QuantAct, self).__init__()
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.act_function = SymQuantizer.apply if quant_mode == 'symmetric' else AsymQuantizer.apply
        # Entry 0 holds the tracked minimum, entry 1 the tracked maximum.
        self.register_buffer('x_min_max', torch.zeros(2))

    def forward(self, x, num_bits, *args):
        """
        Quantize ``x`` to ``num_bits`` bits using the tracked range.

        x: the activation to quantize.
        num_bits: number of bits to quantize the activation to.
        *args: ignored; accepted only so the call signature matches the other
            quantization functions used by the compression layers.
        """
        if self.training:
            self._update_range(x)
        return self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])

    def _update_range(self, x):
        """Fold the min/max of ``x`` into the tracked range (training mode only)."""
        momentum = self.act_range_momentum
        batch_min, batch_max = x.data.min(), x.data.max()
        # Equal entries are treated as "not initialised yet" (the buffer starts as
        # zeros), so the range is first seeded from the current batch.
        if self.x_min_max[0] == self.x_min_max[1]:
            self.x_min_max[0] = batch_min
            self.x_min_max[1] = batch_max
        # Exponential moving average of the observed range.
        self.x_min_max[0] = momentum * self.x_min_max[0] + (1 - momentum) * batch_min
        self.x_min_max[1] = momentum * self.x_min_max[1] + (1 - momentum) * batch_max
class Embedding_Compress(nn.Embedding):
    """
    ``nn.Embedding`` with optional weight quantization.

    Quantization bookkeeping (start_bits, target_bits, q_period) is stored as
    attributes on ``self.weight``.
    """

    def __init__(self, *kargs, **kwargs):
        # Keyword arguments (padding_idx, max_norm, sparse, ...) are forwarded so the
        # layer can be constructed exactly like nn.Embedding.
        super(Embedding_Compress, self).__init__(*kargs, **kwargs)
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False

    def extra_repr(self):
        return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
            self.num_embeddings, self.embedding_dim, self.weight.target_bits)

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """
        Configure weight quantization.

        Args:
            start_bits, target_bits, quantization_period: stored on ``self.weight``.
            weight_quantization_enabled_in_forward: quantize the weight on the fly in
                ``forward`` (only while ``weight_quantization_enabled`` is also set).
            quantization_type: 'symmetric', or anything else for asymmetric. Targets of
                2 bits (ternary) and 1 bit (binary) require 'symmetric'.
            num_groups: ignored; embeddings always use one group per row (token).
        """
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
            )
        if self.weight.target_bits >= 3:
            if quantization_type == 'symmetric':
                self.weight_quantizer = SymQuantizer.apply
            else:
                self.weight_quantizer = AsymQuantizer.apply
        elif self.weight.target_bits == 2:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
            self.weight_quantizer = TernaryQuantizer.apply
        elif self.weight.target_bits == 1:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
            self.weight_quantizer = BinaryQuantizer.apply
        # for embedding, we always use token-wise quantization (one group per row)
        self.weight_quantize_num_groups = self.weight.size(0)

    def fix_weight_quantization(self):
        """Permanently replace the weight values with their quantized version."""
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight

        out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        return out
class LinearLayer_Compress(nn.Linear):
    """
    Linear layer with compression.

    Adds sparse (element-wise) pruning, row pruning, head pruning, weight
    quantization and activation quantization on top of ``nn.Linear``:

    - each technique is set up by an ``enable_*`` method;
    - ``forward`` applies it on the fly while the matching ``*_enabled`` flag is set;
    - the matching ``fix_*`` method bakes it into the parameters.

    Quantization bookkeeping (start_bits, target_bits, q_period) is stored as
    attributes on ``self.weight`` and carried over whenever the weight Parameter
    is replaced.
    """

    def __init__(self, *kargs, bias=True, **kwargs):
        # Extra keyword arguments (e.g. device, dtype) go straight to nn.Linear.
        super(LinearLayer_Compress, self).__init__(*kargs, bias=bias, **kwargs)
        self.sparse_pruning_method = None
        self.row_pruning_method = None
        self.head_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.row_pruning_enabled = False
        self.head_pruning_enabled = False
        self.activation_quantization_enabled = False

    def extra_repr(self):
        return ('in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, '
                'activation quantization={}, weight_quantization={}').format(
                    self.in_features, self.out_features, self.bias is not None,
                    self.sparse_pruning_method is not None, self.row_pruning_method is not None,
                    self.head_pruning_method is not None, self.activation_quantization_method is not None,
                    self.weight.target_bits)

    def _replace_weight(self, data):
        """Swap in a new weight Parameter built from ``data``, keeping the quantization attributes."""
        start_bits = self.weight.start_bits
        target_bits = self.weight.target_bits
        q_period = self.weight.q_period
        self.weight = nn.Parameter(data)
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = q_period

    def enable_sparse_pruning(self, ratio, method):
        """
        Set up element-wise pruning.

        'l1' builds a fixed mask from the weight magnitudes right away. 'topk'
        creates learnable ``sparse_mask_scores``, which ``get_mask`` binarizes.
        Any other method raises NotImplementedError.
        """
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError
        # For 'topk' a None placeholder buffer is registered.
        self.register_buffer('sparse_pruning_mask', mask)

    def enable_row_pruning(self, ratio, method):
        """
        Set up row (output-feature) pruning.

        'l1' ranks rows by their L1 norm once. 'topk' creates learnable
        ``row_mask_scores`` with one score per row. Any other method raises
        NotImplementedError.
        """
        self.row_pruning_ratio = ratio
        self.row_pruning_method = method
        if method == 'l1':
            # L1 norm of each row of the (out_features, in_features) weight,
            # i.e. one value per output feature.
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=1)
            mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
            mask = mask.view(-1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
            self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError
        self.register_buffer('row_pruning_mask', mask)

    def enable_head_pruning(self, ratio, method, num_heads):
        """Set up attention-head pruning; only the 'topk' method is supported."""
        self.num_heads = num_heads
        self.head_pruning_ratio = ratio
        self.head_pruning_method = method
        if method not in ['topk']:
            raise NotImplementedError
        # One learnable score per head; the pruning is applied to the O matrix.
        self.head_pruning_scores = nn.Parameter(torch.Tensor(1, self.num_heads))
        self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
        init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))

    def fix_sparse_pruning_helper(self):
        """Multiply the sparse mask into the weight and drop the pruning state."""
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
        """
        Bake row pruning (``mask=None``) or column pruning (explicit ``mask``) into the weights.

        If there are two back-to-back layers F1 and F2, removing rows from F1 also
        requires removing the matching columns from F2. With only one layer, the
        pruned rows of F1 are merely zeroed.

        Returns:
            The boolean row mask when ``mask`` was None, so it can be reused to prune
            the following layer's columns; otherwise None.
        """
        if mask is None:
            mask = self.get_mask(pruning_type='row').bool()
            if dim_reduction:
                self._replace_weight(self.weight.data[mask.view(-1), :])
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_features = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)

            del self.row_pruning_mask
            if self.row_pruning_method == 'topk':
                del self.row_mask_scores
            self.row_pruning_method = None
        else:
            # Column pruning driven by the previous layer's row mask.
            self._replace_weight(self.weight.data[:, mask.view(-1)])
            self.in_features = self.weight.size(1)
            mask = None
        self.row_pruning_enabled = False
        return mask

    def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
        """
        Bake head pruning into the weights.

        With ``mask=None`` this layer is treated as the O matrix. Its own head mask
        removes (``dim_reduction=True``) or zeroes the input columns belonging to
        pruned heads, and the boolean mask is returned.

        With an explicit ``mask`` (e.g. for Q/K/V) the output rows and bias entries
        of pruned heads are removed, and ``mask`` is returned unchanged.

        ``in_features`` / ``out_features`` are kept in sync with the new weight shape.
        """
        num_heads = num_heads if num_heads else self.num_heads
        if mask is None:
            if self.head_pruning_method != 'topk':
                raise NotImplementedError
            mask = self.get_mask(pruning_type='head').bool()
            if dim_reduction:
                out_dim = self.weight.size(0)
                # Group input columns per head (via the transpose), keep the
                # surviving heads, then restore the (out, in) layout.
                pruned = self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, out_dim).t()
                self._replace_weight(pruned)
                self.in_features = self.weight.size(1)
            else:
                shape = self.weight.size()
                self.weight.data = (self.weight.data.t().reshape(num_heads, -1) * mask.view(-1, 1)).reshape(
                    shape[1], shape[0]).t()

            del self.head_pruning_scores
            self.head_pruning_method = None
        else:
            in_dim = self.weight.size(1)
            self._replace_weight(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, in_dim))
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
            self.out_features = self.weight.size(0)
        self.head_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='row'):
        """Return the current mask for ``pruning_type`` ('sparse', 'row' or 'head')."""
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            if self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            raise NotImplementedError
        if pruning_type == 'row':
            if self.row_pruning_method == 'l1':
                return self.row_pruning_mask.to(self.weight.device)
            if self.row_pruning_method == 'topk':
                return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
            raise NotImplementedError
        if pruning_type == 'head':
            if self.head_pruning_method == 'topk':
                return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
            raise NotImplementedError
        raise NotImplementedError

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """
        Configure weight quantization.

        Bit widths >= 3 use Sym/AsymQuantizer. 2 bits (ternary) and 1 bit (binary)
        require 'symmetric' quantization. ``num_groups`` is passed to the quantizer.
        """
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
            )
        if self.weight.target_bits >= 3:
            if quantization_type == 'symmetric':
                self.weight_quantizer = SymQuantizer.apply
            else:
                self.weight_quantizer = AsymQuantizer.apply
        elif self.weight.target_bits == 2:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
            self.weight_quantizer = TernaryQuantizer.apply
        elif self.weight.target_bits == 1:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
            self.weight_quantizer = BinaryQuantizer.apply
        self.weight_quantize_num_groups = num_groups

    def fix_weight_quantization(self):
        """Permanently replace the weight values with their quantized version."""
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        """
        Configure input-activation quantization.

        'static' calibration uses a QuantAct module with a tracked range. Any other
        calibration quantizes per call; ``forward`` uses one group per row when the
        method name contains 'dynamic'.
        """
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def head_pruning_reshape(self, w, mask):
        """Zero the input columns of ``w`` that belong to heads masked out by ``mask``."""
        shape = w.shape
        return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()

    def forward(self, input, skip_bias_add=False):
        """
        Apply the enabled compression techniques, then the linear map.

        Returns ``output``, or ``(output, bias)`` without adding the bias when
        ``skip_bias_add`` is True (used by the model-parallel linear layers).
        """
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight
        bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.row_pruning_enabled and self.row_pruning_method:
            mask = self.get_mask(pruning_type='row')
            weight = weight * mask.view(-1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.head_pruning_enabled and self.head_pruning_method:
            mask = self.get_mask(pruning_type='head')
            weight = self.head_pruning_reshape(weight, mask)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                # One quantization group per row of the flattened input.
                num_groups = input.numel() // input.size(-1)
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        if skip_bias_add:
            # used for mpu linear layers
            output = nn.functional.linear(input, weight, None)
            return output, bias
        return nn.functional.linear(input, weight, bias)
class Conv2dLayer_Compress(nn.Conv2d):
    """
    Conv2D layer with compression.

    Adds sparse (element-wise) pruning, output-channel pruning, weight
    quantization and activation quantization on top of ``nn.Conv2d``.
    Quantization bookkeeping (start_bits, target_bits, q_period) is stored as
    attributes on ``self.weight``.
    """

    def __init__(self, *kargs, **kwargs):
        # Keyword arguments (stride, padding, padding_mode, bias, ...) are forwarded
        # so the layer can be constructed exactly like nn.Conv2d.
        super(Conv2dLayer_Compress, self).__init__(*kargs, **kwargs)
        self.sparse_pruning_method = None
        self.channel_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        # forward() reads this flag; initialise it like the other compress layers do
        # so quantize-in-forward never hits a missing attribute.
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.channel_pruning_enabled = False
        self.activation_quantization_enabled = False

    def __repr__(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        output = s.format(**self.__dict__)

        return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)

    def _replace_weight(self, data):
        """Swap in a new weight Parameter built from ``data``, keeping the quantization attributes."""
        start_bits = self.weight.start_bits
        target_bits = self.weight.target_bits
        q_period = self.weight.q_period
        self.weight = nn.Parameter(data)
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = q_period

    def enable_sparse_pruning(self, ratio, method):
        """
        Set up element-wise pruning.

        'l1' builds a fixed mask from the weight magnitudes. 'topk' creates
        learnable ``sparse_mask_scores``. Other methods raise NotImplementedError.
        """
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError
        self.register_buffer('sparse_pruning_mask', mask)

    def enable_channel_pruning(self, ratio, method):
        """
        Set up output-channel pruning.

        'l1' ranks output channels by the L1 norm of their kernels. 'topk' creates
        one learnable score per output channel. Other methods raise
        NotImplementedError.
        """
        self.channel_pruning_ratio = ratio
        self.channel_pruning_method = method
        if method == 'l1':
            # L1 norm of each output channel's kernel (the last three dimensions).
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=[1, 2, 3])
            mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
            mask = mask.view(-1, 1, 1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
            self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError
        self.register_buffer('channel_pruning_mask', mask)

    def fix_sparse_pruning_helper(self):
        """Multiply the sparse mask into the weight and drop the pruning state."""
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
        """
        Bake channel pruning into the weights.

        With ``mask=None`` this layer's own output-channel mask is applied. Pruned
        output channels are removed (``dim_reduction=True``) or zeroed, and the
        boolean mask is returned so dependent layers can be pruned consistently.

        With an explicit ``mask`` the corresponding input channels are removed
        and None is returned.

        ``out_channels`` / ``in_channels`` are kept in sync with the new weight shape.
        """
        if mask is None:
            if self.channel_pruning_method not in ['l1', 'topk']:
                raise NotImplementedError
            mask = self.get_mask(pruning_type='channel').bool()
            if dim_reduction:
                self._replace_weight(self.weight.data[mask.view(-1), ...])
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_channels = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)
            del self.channel_pruning_mask
            if self.channel_pruning_method == 'topk':
                del self.channel_mask_scores
            self.channel_pruning_method = None
        else:
            self._replace_weight(self.weight.data[:, mask.view(-1), ...])
            # The conv weight is laid out as (out, in // groups, kh, kw).
            self.in_channels = self.weight.size(1) * self.groups
            mask = None
        self.channel_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='sparse'):
        """Return the current mask for ``pruning_type`` ('sparse' or 'channel')."""
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            if self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            raise NotImplementedError
        if pruning_type == 'channel':
            if self.channel_pruning_method == 'l1':
                return self.channel_pruning_mask.to(self.weight.device)
            if self.channel_pruning_method == 'topk':
                return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
            raise NotImplementedError
        raise NotImplementedError

    def fix_weight_quantization(self):
        """Permanently replace the weight values with their quantized version."""
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """Configure weight quantization (symmetric or asymmetric only for conv layers)."""
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            assert self.weight.target_bits >= 4, 'Only >=4 bits weight quantization are supported during forward pass for now'
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
            )
        if quantization_type == 'symmetric':
            self.weight_quantizer = SymQuantizer.apply
        else:
            self.weight_quantizer = AsymQuantizer.apply
        self.weight_quantize_num_groups = num_groups

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        """Configure input-activation quantization ('static' uses a tracked range)."""
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
        if range_calibration == 'static':
            self.activation_quantizer = QuantAct(quant_mode=quantization_type)
        else:
            if quantization_type == 'symmetric':
                self.activation_quantizer = SymQuantizer.apply
            else:
                self.activation_quantizer = AsymQuantizer.apply

    def forward(self, input):
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight
        bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.channel_pruning_enabled and self.channel_pruning_method:
            mask = self.get_mask(pruning_type='channel')
            weight = weight * mask.view(-1, 1, 1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                # One quantization group per sample along the leading dimension.
                num_groups = input.numel() // input[0].numel()
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        # _conv_forward honours padding_mode (reflect/replicate/circular); calling
        # conv2d directly with self.padding would silently zero-pad instead.
        return self._conv_forward(input, weight, bias)
class BNLayer_Compress(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose channels can be pruned to match a pruned conv layer."""

    def fix_channel_pruning_helper(self, mask, dim_reduction=True):
        """
        Keep only the channels selected by ``mask``.

        Args:
            mask: boolean tensor with one entry per channel (flattened before use),
                e.g. the mask returned by a conv layer's channel-pruning fix.
            dim_reduction: unused; kept so the signature mirrors the conv layer's.
        """
        keep = mask.view(-1)
        # affine=False leaves weight/bias as None and track_running_stats=False
        # leaves the running statistics as None, so only slice what exists.
        if self.weight is not None:
            self.weight = nn.Parameter(self.weight.data[keep])
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias.data[keep])
        if self.running_mean is not None:
            self.running_mean = self.running_mean[keep]
        if self.running_var is not None:
            self.running_var = self.running_var[keep]
        # Keep the reported channel count in sync with the sliced tensors.
        self.num_features = int(keep.sum())
def _reduce(input_):
    """All-reduce the input tensor across the model-parallel group.

    ``dist.all_reduce`` is called on ``input_`` itself and that same tensor is
    returned. When the group has a single rank, nothing is communicated.
    """
    group = g_mpu.get_model_parallel_group()
    if dist.get_world_size(group=group) != 1:
        dist.all_reduce(input_, group=group)
    return input_
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor into ``num_partitions`` equal chunks along its last dimension.

    Arguments:
        tensor: input tensor; its last dimension must be divisible by
            ``num_partitions`` (checked with an assert).
        num_partitions: number of chunks to produce.
        contiguous_split_chunks: if True, return contiguous copies of the chunks.

    Returns:
        A tuple of chunks. Without ``contiguous_split_chunks`` they are views
        produced by ``torch.split``, which need not be contiguous.
    """
    split_dim = tensor.dim() - 1
    full_size = tensor.size()[split_dim]
    assert full_size % num_partitions == 0
    pieces = torch.split(tensor, full_size // num_partitions, dim=split_dim)
    if not contiguous_split_chunks:
        return pieces
    return tuple(piece.contiguous() for piece in pieces)
def _split(input_):
    """Return this rank's contiguous slice of ``input_`` along its last dimension.

    The slice is chosen by the rank within the model-parallel group. With a
    single-rank group the input is returned unchanged.
    """
    group = g_mpu.get_model_parallel_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    slices = split_tensor_along_last_dim(input_, world_size)
    # torch.split yields views that need not be contiguous, hence the copy.
    return slices[dist.get_rank(group=group)].contiguous()
def _gather(input_):
    """All-gather ``input_`` from every model-parallel rank and concatenate on the last dim.

    With a single-rank group the input is returned unchanged.
    """
    group = g_mpu.get_model_parallel_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    rank = dist.get_rank(group=group)
    gathered = [torch.empty_like(input_) for _ in range(world_size)]
    gathered[rank] = input_
    dist.all_gather(gathered, input_, group=group)
    # torch.cat already returns a contiguous tensor; .contiguous() is a no-op guard.
    return torch.cat(gathered, dim=input_.dim() - 1).contiguous()
- class _CopyToModelParallelRegion(torch.autograd.Function):
- """Pass the input to the model parallel region."""
- @staticmethod
- def forward(ctx, input_):
- return input_
- @staticmethod
- def backward(ctx, grad_output):
- return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input coming out of the model-parallel region.

    Forward all-reduces; backward passes the gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, tensor):
        return _reduce(tensor)

    @staticmethod
    def backward(ctx, grad):
        return grad
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Keep only this rank's slice (last dim) of the input.

    Forward splits the input; backward gathers the gradient slices back together.
    """

    @staticmethod
    def forward(ctx, tensor):
        return _split(tensor)

    @staticmethod
    def backward(ctx, grad):
        return _gather(grad)
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the per-rank inputs and concatenate them along the last dim.

    Forward gathers; backward keeps only this rank's slice of the gradient.
    """

    @staticmethod
    def forward(ctx, tensor):
        return _gather(tensor)

    @staticmethod
    def backward(ctx, grad):
        return _split(grad)
- # -----------------
- # Helper functions.
- # -----------------
def copy_to_model_parallel_region(input_):
    """Identity in the forward pass; the gradient is all-reduced in the backward pass."""
    copy_fn = _CopyToModelParallelRegion.apply
    return copy_fn(input_)
def reduce_from_model_parallel_region(input_):
    """All-reduce in the forward pass; the gradient passes through unchanged."""
    reduce_fn = _ReduceFromModelParallelRegion.apply
    return reduce_fn(input_)
def scatter_to_model_parallel_region(input_):
    """Keep this rank's last-dim slice going forward; gather gradients going backward."""
    scatter_fn = _ScatterToModelParallelRegion.apply
    return scatter_fn(input_)
def gather_from_model_parallel_region(input_):
    """Concatenate all ranks' tensors going forward; split gradients going backward."""
    gather_fn = _GatherFromModelParallelRegion.apply
    return gather_fn(input_)
class ColumnParallelLinear_Compress(LinearLayer_Compress):
    """
    Compressible linear layer whose output features are partitioned across the
    model-parallel ranks.

    Each rank holds ``output_size // world_size`` output features.
    """

    def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
        # Publish the mpu for the module-level collective helpers.
        global g_mpu
        g_mpu = mpu

        world_size = mpu.get_model_parallel_world_size()
        assert output_size % world_size == 0
        per_partition = output_size // world_size
        super(ColumnParallelLinear_Compress, self).__init__(input_size, per_partition, bias=bias)

        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add
        self.output_size_per_partition = per_partition

    def forward(self, input_):
        """
        Return ``(output, bias)``.

        ``bias`` is None unless ``skip_bias_add`` is set, in which case it is
        returned instead of being added. ``output`` is gathered across ranks
        when ``gather_output`` is set.
        """
        # Identity forward; gradients are all-reduced on the way back.
        input_parallel = copy_to_model_parallel_region(input_)

        if self.skip_bias_add:
            output_parallel, bias = super().forward(input_parallel, True)
        else:
            output_parallel, bias = super().forward(input_parallel), None

        output = gather_from_model_parallel_region(output_parallel) if self.gather_output else output_parallel
        return output, bias
class RowParallelLinear_Compress(LinearLayer_Compress):
    """
    Compressible linear layer whose input features are partitioned across the
    model-parallel ranks.

    Each rank holds ``input_size // world_size`` input features.
    """

    def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
        # Publish the mpu for the module-level collective helpers.
        global g_mpu
        g_mpu = mpu

        world_size = mpu.get_model_parallel_world_size()
        assert input_size % world_size == 0
        per_partition = input_size // world_size
        super(RowParallelLinear_Compress, self).__init__(per_partition, output_size, bias=bias)

        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add
        self.input_size_per_partition = per_partition

    def forward(self, input_):
        """
        Return ``(output, output_bias)``.

        The partial products are all-reduced across ranks. The bias is then added
        once, unless ``skip_bias_add`` is set, in which case it is returned as
        ``output_bias`` instead (and ``output_bias`` is None otherwise).
        """
        input_parallel = input_ if self.input_is_parallel else scatter_to_model_parallel_region(input_)

        # The bias is held back here so it is added only after the cross-rank sum.
        output_parallel, bias = super().forward(input_parallel, True)
        reduced = reduce_from_model_parallel_region(output_parallel)

        if self.skip_bias_add:
            return reduced, bias
        if bias is None:
            return reduced, None
        return reduced + bias, None
|