module_quantize.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
      """Cast the linear weights of matching transformer layers to ``torch.int8``.

      For every layer handed to the quantize callback, the ``weight.data`` of
      its attention and MLP linear projections is replaced by
      ``weight.data.to(torch.int8)``. The layer objects themselves are kept;
      only the tensors behind their weights change.

      NOTE(review): the conversion is a plain dtype cast; no scale or
      zero-point is computed here. Presumably the weights already hold
      integer-range values when this is called -- confirm against the caller.

      Arguments:
          orig_layer_impl (torch.nn.Module): the transformer layer class to look for,
              e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
          model (torch.nn.Module): user's nn.module representing their model
          megatron (bool): if True, layers are read with the Megatron layout
              (``attention.query_key_value``, ``attention.dense``,
              ``mlp.dense_h_to_4h``, ``mlp.dense_4h_to_h``); otherwise the
              BERT layout is used.
          preln (bool): BERT layout only. If True the intermediate projection
              is ``intermediate.dense_act``; otherwise ``intermediate.dense``.

      Returns:
          Whatever ``quantize_module`` returns for ``model``.
      """

      def _cast_weight_to_int8(linear):
          # Rebind ``.data`` so the existing weight object is kept and only
          # the tensor behind it changes dtype.
          linear.weight.data = linear.weight.data.to(torch.int8)

      def _quantize_megatron(layer):
          # Projections are processed one at a time, in this order, so a
          # missing attribute fails at the same point as a hand-written
          # sequence of assignments would.
          _cast_weight_to_int8(layer.attention.query_key_value)
          _cast_weight_to_int8(layer.attention.dense)
          _cast_weight_to_int8(layer.mlp.dense_h_to_4h)
          _cast_weight_to_int8(layer.mlp.dense_4h_to_h)

      def _quantize_bert(layer):
          self_attention = layer.attention.self
          for projection in ("query", "key", "value"):
              _cast_weight_to_int8(getattr(self_attention, projection))
          _cast_weight_to_int8(layer.attention.output.dense)

          # Pre-layer-norm implementations name the intermediate projection
          # ``dense_act`` instead of ``dense``.
          intermediate = layer.intermediate
          _cast_weight_to_int8(intermediate.dense_act if preln else intermediate.dense)
          _cast_weight_to_int8(layer.output.dense)

      # Megatron GPT2 / GPT3 trained model vs. DeepSpeed or HuggingFace trained model.
      quantize_layer = _quantize_megatron if megatron else _quantize_bert

      def quantize_fn(child):
          quantize_layer(child)
          return child

      return quantize_module(model=model, orig_class=orig_layer_impl, quantize_fn=quantize_fn)
  def quantize_module(model, orig_class, quantize_fn):
      """Run ``quantize_fn`` over the ``orig_class`` modules found under ``model``.

      Builds a one-entry policy table mapping ``orig_class`` to ``quantize_fn``
      and delegates the traversal to ``_quantize_module``.

      Arguments:
          model (torch.nn.Module): root module to traverse.
          orig_class (type): module class whose instances should be processed.
          quantize_fn (callable): called with each matching child; its return
              value is stored in place of that child.

      Returns:
          The value returned by ``_quantize_module`` for ``model``.
      """
      return _quantize_module(model, {orig_class: quantize_fn})
  46. def _quantize_module(model, policies):
  47. for name, child in model.named_children():
  48. if child.__class__ in policies:
  49. orig = repr(child)
  50. setattr(model, name, policies[child.__class__](child))
  51. new = getattr(model, name)
  52. else:
  53. _quantize_module(child, policies)
  54. return model