'''Copyright The Microsoft DeepSpeed Team'''

import torch


def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
    """ Quantize bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        megatron (bool): megatron model-parallel implementation (supported for inference only)
        preln (bool): does the original layer implementation use pre- or post-layer norm?

        Note: for BERT-style models, injection follows the DeepSpeed-Examples layer
            definitions unless the huggingface flag is set.

    Returns:
        Updated nn.module with quantized transformer layers
    """

    def quantize_weight(weight):
        # A straight dtype cast (fractional parts are truncated), not a
        # scale-and-round quantizer.
        return weight.to(torch.int8)

    def megatron_layer_quantize(layer):
        # Megatron fuses Q/K/V into a single query_key_value projection,
        # so only four weight matrices are touched per layer.
        layer.attention.query_key_value.weight.data = quantize_weight(
            layer.attention.query_key_value.weight.data)
        layer.attention.dense.weight.data = quantize_weight(
            layer.attention.dense.weight.data)
        layer.mlp.dense_h_to_4h.weight.data = quantize_weight(
            layer.mlp.dense_h_to_4h.weight.data)
        layer.mlp.dense_4h_to_h.weight.data = quantize_weight(
            layer.mlp.dense_4h_to_h.weight.data)

    def bert_layer_quantize(layer):
        # BERT keeps separate Q/K/V projections.
        layer.attention.self.query.weight.data = quantize_weight(
            layer.attention.self.query.weight.data)
        layer.attention.self.key.weight.data = quantize_weight(
            layer.attention.self.key.weight.data)
        layer.attention.self.value.weight.data = quantize_weight(
            layer.attention.self.value.weight.data)
        layer.attention.output.dense.weight.data = quantize_weight(
            layer.attention.output.dense.weight.data)
        # Pre-LN (DeepSpeed-Examples style) layers name the intermediate
        # projection dense_act; post-LN (HuggingFace-style) layers name it dense.
        if preln:
            layer.intermediate.dense_act.weight.data = quantize_weight(
                layer.intermediate.dense_act.weight.data)
        else:
            layer.intermediate.dense.weight.data = quantize_weight(
                layer.intermediate.dense.weight.data)
        layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)

    def quantize_fn(child):
        if megatron:
            # Quantize megatron GPT2 / GPT3 trained model
            megatron_layer_quantize(child)
        else:
            # Quantize either DeepSpeed or HuggingFace trained model
            bert_layer_quantize(child)

        return child

    return quantize_module(model=model,
                           orig_class=orig_layer_impl,
                           quantize_fn=quantize_fn)
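

# --- Hedged aside, not part of the original module: quantize_weight above is a
# plain dtype cast. A minimal symmetric per-tensor int8 quantizer would also
# compute a scale and round, roughly as sketched below. _symmetric_int8_quantize
# is a hypothetical helper added for illustration only; DeepSpeed's real
# quantization kernels live elsewhere.
def _symmetric_int8_quantize(weight):
    # Map the largest-magnitude value to 127; clamp the scale so an all-zero
    # tensor does not divide by zero.
    scale = weight.abs().max().clamp(min=1e-8) / 127.0
    # Dequantize later with q.float() * scale.
    q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return q, scale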


def quantize_module(model, orig_class, quantize_fn):
    # A one-entry policy: map the target layer class to its quantization function.
    policy = {orig_class: quantize_fn}
    return _quantize_module(model, policy)


def _quantize_module(model, policies):
    # Walk the module tree; wherever a child's class matches a policy, swap the
    # child for its quantized version in-place, otherwise recurse into it.
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _quantize_module(child, policies)

    return model
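

# --- Usage sketch (hedged), not part of the original module. Assumes the
# HuggingFace transformers package is installed; the BertLayer import path is
# version-dependent (older releases exposed transformers.modeling_bert.BertLayer).
if __name__ == "__main__":
    from transformers import BertConfig, BertModel
    from transformers.models.bert.modeling_bert import BertLayer

    # A small randomly initialized BERT, so no checkpoint download is needed.
    model = BertModel(BertConfig(num_hidden_layers=2))
    model = quantize_transformer_layer(BertLayer, model, megatron=False, preln=False)
    # Every matched layer now holds int8 weight tensors.
    print(model.encoder.layer[0].attention.self.query.weight.dtype)  # torch.int8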