module_quantize.py

import copy
import torch
import deepspeed


def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
    """ Quantize BERT-style transformer layers with DeepSpeed's transformer layer.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        megatron (bool): megatron model-parallel implementation (this is supported for inference only)
        preln (bool): does the original layer implementation do pre- or post-layer norm?

        Note: for BERT-style models, injection follows the DeepSpeed-Examples model layout
        unless the HuggingFace flag is set.

    Returns:
        Updated nn.Module with quantized transformer layers
    """

    def quantize_weight(weight):
        # Note: this is a plain dtype cast to int8, not a scale/zero-point
        # quantization scheme; fractional parts are simply truncated.
        return weight.to(torch.int8)

    def megatron_layer_quantize(layer):
        # Quantize the fused QKV / attention-output projections and the two MLP projections.
        layer.attention.query_key_value.weight.data = quantize_weight(
            layer.attention.query_key_value.weight.data)
        layer.attention.dense.weight.data = quantize_weight(
            layer.attention.dense.weight.data)
        layer.mlp.dense_h_to_4h.weight.data = quantize_weight(
            layer.mlp.dense_h_to_4h.weight.data)
        layer.mlp.dense_4h_to_h.weight.data = quantize_weight(
            layer.mlp.dense_4h_to_h.weight.data)

    def bert_layer_quantize(layer):
        # Quantize the separate Q/K/V projections and the attention output projection.
        layer.attention.self.query.weight.data = quantize_weight(
            layer.attention.self.query.weight.data)
        layer.attention.self.key.weight.data = quantize_weight(
            layer.attention.self.key.weight.data)
        layer.attention.self.value.weight.data = quantize_weight(
            layer.attention.self.value.weight.data)
        layer.attention.output.dense.weight.data = quantize_weight(
            layer.attention.output.dense.weight.data)
        # Pre-LN layers name the intermediate projection differently.
        if preln:
            layer.intermediate.dense_act.weight.data = quantize_weight(
                layer.intermediate.dense_act.weight.data)
        else:
            layer.intermediate.dense.weight.data = quantize_weight(
                layer.intermediate.dense.weight.data)
        layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)

    def quantize_fn(child):
        if megatron:
            # Quantize a Megatron GPT-2 / GPT-3 trained model
            megatron_layer_quantize(child)
        else:
            # Quantize either a DeepSpeed- or HuggingFace-trained model
            bert_layer_quantize(child)

        return child

    return quantize_module(model=model,
                           orig_class=orig_layer_impl,
                           quantize_fn=quantize_fn)
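

# For reference, a scale-based symmetric int8 quantizer is sketched below.
# This is an illustrative alternative, not part of the original module:
# quantize_weight above performs a plain dtype cast, whereas a scale-based
# scheme rescales weights into the representable int8 range first (the scale
# would also need to be stored for dequantization; that is omitted here).
def _symmetric_quantize_weight_sketch(weight):
    # Per-tensor symmetric quantization: scale by max(|w|) / 127,
    # round to the nearest integer, and clamp into the int8 range.
    scale = weight.abs().max().clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(weight / scale), min=-127, max=127).to(torch.int8)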


def quantize_module(model, orig_class, quantize_fn):
    # Map the target layer class to the function that quantizes it in place.
    policy = {orig_class: quantize_fn}
    return _quantize_module(model, policy)


def _quantize_module(model, policies):
    # Walk the module tree: replace any child whose class has a policy with
    # the module returned by that policy, and recurse into the rest.
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _quantize_module(child, policies)

    return model
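

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes the
# HuggingFace `transformers` package is installed; `BertModel` / `BertLayer`
# are the standard HuggingFace classes, and the checkpoint name is an
# arbitrary example. Since quantize_weight is a plain int8 cast, the result
# is not usable for inference; this only shows how the entry point is called.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertModel
    from transformers.models.bert.modeling_bert import BertLayer

    model = BertModel.from_pretrained("bert-base-uncased")
    model = quantize_transformer_layer(BertLayer, model, megatron=False, preln=False)
    print(model)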