123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- '''Copyright The Microsoft DeepSpeed Team'''
- import torch
- from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
- from .constants import *
- def recursive_getattr(model, module_name):
- """
- Recursively get the attribute of a module.
- Args:
- model (`torch.nn.Module`)
- The model to get the attribute from.
- module_name (`str`)
- The name of the module to get the attribute from.
- """
- split_list = module_name.split('.')
- output = model
- for name in split_list:
- output = getattr(output, name)
- return output
- def recursive_setattr(model, module_name, module):
- """
- Recursively set the attribute of a module.
- Args:
- model (`torch.nn.Module`)
- The model to set the attribute in.
- module_name (`str`)
- The name of the module to set the attribute in.
- module (`torch.nn.Module`)
- The module to set the attribute to.
- """
- split_list = module_name.split('.')
- output = model
- for name in split_list[:-1]:
- output = getattr(output, name)
- output.__setattr__(split_list[-1], module)
- def module_replacement(model, module_name, compression_technique=None, mpu=None):
- """
- Replace a module with a new module.
- Args:
- model (`torch.nn.Module`)
- The model to replace the module in.
- module_name (`str`)
- The name of the module to replace.
- compression_technique (`str`)
- The compression technique to use for the new module.
- """
- # Get the old module
- old_module = recursive_getattr(model, module_name)
- need_bias = False
- if hasattr(old_module, 'bias') and old_module.bias is not None:
- need_bias = True
- # Initialize the new module
- if isinstance(old_module,
- LinearLayer_Compress) or isinstance(old_module,
- torch.nn.Linear):
- if isinstance(old_module, LinearLayer_Compress):
- new_module = old_module
- else:
- new_module = LinearLayer_Compress(old_module.in_features,
- old_module.out_features,
- bias=need_bias).to(
- device=old_module.weight.device,
- dtype=old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- if need_bias:
- new_module.bias.data = old_module.bias.data
- elif isinstance(old_module,
- Conv2dLayer_Compress) or isinstance(old_module,
- torch.nn.Conv2d):
- if isinstance(old_module, Conv2dLayer_Compress):
- new_module = old_module
- else:
- new_module = Conv2dLayer_Compress(old_module.in_channels, old_module.out_channels, old_module.kernel_size, old_module.stride, old_module.padding, \
- old_module.dilation, old_module.groups, need_bias, \
- old_module.padding_mode).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- if need_bias:
- new_module.bias.data = old_module.bias.data
- elif isinstance(old_module, torch.nn.BatchNorm2d):
- new_module = BNLayer_Compress(old_module.num_features,
- old_module.eps,
- old_module.momentum,
- old_module.affine,
- old_module.track_running_stats).to(
- old_module.weight.device,
- old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- if need_bias:
- new_module.bias.data = old_module.bias.data
- new_module.running_mean.data = old_module.running_mean.data
- new_module.running_var.data = old_module.running_var.data
- elif isinstance(old_module,
- Embedding_Compress) or isinstance(old_module,
- torch.nn.Embedding):
- if isinstance(old_module, Embedding_Compress):
- new_module = old_module
- else:
- new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
- old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- elif mpu is not None and (isinstance(old_module,
- ColumnParallelLinear_Compress)
- or isinstance(old_module,
- mpu.ColumnParallelLinear)):
- if isinstance(old_module, ColumnParallelLinear_Compress):
- new_module = old_module
- else:
- new_module = ColumnParallelLinear_Compress(
- mpu,
- old_module.input_size,
- old_module.output_size,
- gather_output=old_module.gather_output,
- skip_bias_add=old_module.skip_bias_add,
- bias=need_bias).to(device=old_module.weight.device,
- dtype=old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- if need_bias:
- new_module.bias.data = old_module.bias.data
- elif mpu is not None and (isinstance(old_module,
- RowParallelLinear_Compress)
- or isinstance(old_module,
- mpu.RowParallelLinear)):
- if isinstance(old_module, RowParallelLinear_Compress):
- new_module = old_module
- else:
- new_module = RowParallelLinear_Compress(
- mpu,
- old_module.input_size,
- old_module.output_size,
- input_is_parallel=old_module.input_is_parallel,
- skip_bias_add=old_module.skip_bias_add,
- bias=need_bias).to(device=old_module.weight.device,
- dtype=old_module.weight.dtype)
- new_module.weight.data = old_module.weight.data
- if need_bias:
- new_module.bias.data = old_module.bias.data
- else:
- new_module = None
- if compression_technique is not None:
- for k, v in compression_technique.items():
- if k == SPARSE_PRUNING:
- if v[SPARSE_PRUNING_ENABLED]:
- new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
- v[SPARSE_PRUNING_METHOD])
- elif k == ROW_PRUNING:
- if v[ROW_PRUNING_ENABLED]:
- new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
- v[ROW_PRUNING_METHOD])
- elif k == HEAD_PRUNING:
- if v[HEAD_PRUNING_ENABLED]:
- new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
- v[HEAD_PRUNING_METHOD],
- v[HEAD_PRUNING_NUM_HEADS])
- elif k == ACTIVATION_QUANTIZATION:
- if v[ACTIVATION_QUANTIZATION_ENABLED]:
- new_module.enable_activation_quantization(
- v[ACTIVATION_QUANTIZE_BITS],
- v[ACTIVATION_QUANTIZE_TYPE],
- v[ACTIVATION_QUANTIZE_RANGE])
- elif k == WEIGHT_QUANTIZATION:
- if v[WEIGHT_QUANTIZE_ENABLED]:
- new_module.enable_weight_quantization(
- v[WEIGHT_QUANTIZE_START_BITS],
- v[WEIGHT_QUANTIZE_TARGET_BITS],
- v[WEIGHT_QUANTIZATION_PERIOD],
- v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
- v[WEIGHT_QUANTIZE_TYPE],
- v[WEIGHT_QUANTIZE_GROUPS])
- elif k == CHANNEL_PRUNING:
- if v[CHANNEL_PRUNING_ENABLED]:
- new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
- v[CHANNEL_PRUNING_METHOD])
- else:
- raise NotImplementedError(
- 'Compression technique {} is not implemented'.format(k))
- # Replace the old module with the new one
- recursive_setattr(model, module_name, new_module)
- def is_module_compressible(module, mpu=None):
- ret = isinstance(module, torch.nn.Linear) or \
- isinstance(module, torch.nn.Conv2d) or \
- isinstance(module, torch.nn.Embedding) or \
- isinstance(module, torch.nn.BatchNorm2d)
- if mpu is not None:
- ret = ret or isinstance(module,
- mpu.RowParallelLinear) or isinstance(
- module,
- mpu.ColumnParallelLinear)
- return ret
- def compression_preparation(model, compression_techinique_list, mpu):
- """
- Prepare the compression techniques of a model.
- Args:
- model (`torch.nn.Module`)
- The model to prepare the compression techniques of.
- compression_techinique_list (`list`)
- The list of compression techniques to prepare the model to.
- list[]
- """
- # Here we first replace all module with our linear wrapper
- for module_name, module in model.named_modules():
- if is_module_compressible(module, mpu):
- module_replacement(model, module_name, mpu=mpu)
- for module_name_lists, _, compression_technique in compression_techinique_list:
- for mnl in module_name_lists:
- for module_name in mnl:
- module_replacement(model, module_name, compression_technique)
- return model
- def fix_compression(model,
- module_name,
- compression_technique,
- mask=None,
- dim_reduction=False):
- """
- Fix the compression technique of a module.
- Args:
- model (`torch.nn.Module`)
- The model to fix the compression technique of.
- module_name (`str`)
- The name of the module to fix the compression technique of.
- compression_technique (`str`)
- The compression technique to fix the module to.
- """
- # Here we can make things much simpler by just replacing the module
- module = recursive_getattr(model, module_name)
- for k, v in compression_technique.items():
- if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[
- WEIGHT_QUANTIZE_ENABLED]:
- return module.fix_weight_quantization()
- elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
- return module.fix_sparse_pruning_helper()
- elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
- return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
- elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
- return module.fix_head_pruning_helper(mask,
- v[HEAD_PRUNING_NUM_HEADS],
- dim_reduction=dim_reduction)
- elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
- return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
- def convert_conv1d_to_linear(model, convert_type):
- '''
- This is a help function to convert conv1d to linear (e.g., convert GPT2 from HF)
- '''
- if hasattr(model, 'module'):
- c_model = model.module
- else:
- c_model = model
- for name, module in c_model.named_modules():
- if isinstance(module, convert_type):
- old_module = recursive_getattr(c_model, name)
- new_module = torch.nn.Linear(
- old_module.weight.data.size(0),
- old_module.weight.data.size(1),
- bias=True if old_module.bias is not None else False)
- new_module.weight.data = old_module.weight.data.t().contiguous()
- if new_module.bias is not None:
- new_module.bias.data = old_module.bias.data.view(-1)
- recursive_setattr(c_model, name, new_module)
- return model
|