gptneo.py
'''Copyright The Microsoft DeepSpeed Team'''
from .base import *
from .features.meta_tensor import MetaTensorContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv


class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        # Build the fused DeepSpeed inference module for this layer, using the
        # provided config or falling back to the container's ds_model_config.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Checkpoint parameter names of a Hugging Face GPT-Neo block, in the
        # order expected by the copy loops below.
        param_names = (
            'attn.attention.q_proj.weight',
            'attn.attention.k_proj.weight',
            'attn.attention.v_proj.weight',
            'attn.attention.out_proj.weight',
            'attn.attention.out_proj.bias',
            'mlp.c_fc.weight',
            'mlp.c_fc.bias',
            'mlp.c_proj.weight',
            'mlp.c_proj.bias',
            'ln_2.weight',
            'ln_2.bias',
            'ln_1.weight',
            'ln_1.bias',
        )
        # Fuse the separate q/k/v projection weights into the single attn_qkvw
        # parameter of the inference module.
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw',
                       [prefix + param_names[0],
                        prefix + param_names[1],
                        prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        # Attention output projection weight and bias.
        for i in range(3, 5):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i - 1],
                       prefix + param_names[i])
        # MLP weights and biases, plus the pre-MLP layer norm (ln_2), copied
        # into the MLP sub-module.
        for i in range(5, 11):
            maybe_copy(module.mlp,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i - 1],
                       prefix + param_names[i])
        # Pre-attention layer norm (ln_1), copied into the top-level module.
        for i in range(11, 13):
            maybe_copy(module,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i - 1],
                       prefix + param_names[i])
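
# Conceptual sketch (added for illustration; not part of the original module).
# The fused 'attn_qkvw' parameter populated by maybe_copy_qkv above corresponds
# to stacking the three HF projection weights along the output dimension, e.g.
# for hidden size H:
#
#     import torch
#     q, k, v = (torch.randn(H, H) for _ in range(3))
#     qkv = torch.cat((q, k, v), dim=0)  # shape [3 * H, H]
#
# The real helper additionally handles meta tensors, optional weight
# quantization and tensor-parallel sharding via mp_replace, so it is not a
# plain concatenation.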


class HFGPTNEOLayerPolicy(TransformerPolicy):
    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=False)
        self.client_module = client_module
        try:
            import transformers
            HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
        except:
            # transformers is unavailable (or too old), so automatic module
            # replacement by class match is disabled for this policy.
            HFGPTNEOLayerPolicy._orig_layer_class = None
    def get_hidden_heads(self):
        return self.client_module.attn.attention.q_proj.weight.shape[1], \
               self.client_module.attn.attention.num_heads

    def attention(self):
        qw = self.client_module.attn.attention.q_proj.weight
        kw = self.client_module.attn.attention.k_proj.weight
        vw = self.client_module.attn.attention.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)

        # GPT-Neo's q/k/v projections have no bias, hence None for the fused
        # qkv bias.
        return qkvw, \
               None, \
               self.client_module.attn.attention.out_proj.weight, \
               self.client_module.attn.attention.out_proj.bias

    def mlp(self):
        return self.client_module.mlp.c_fc.weight, \
               self.client_module.mlp.c_fc.bias, \
               self.client_module.mlp.c_proj.weight, \
               self.client_module.mlp.c_proj.bias

    def layernorm(self):
        return self.client_module.ln_2.weight, \
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias
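
# Illustrative usage sketch (added for context; not part of the original file,
# and keyword arguments vary across DeepSpeed releases). Kernel injection
# matches Hugging Face GPTNeoBlock layers via _orig_layer_class and replaces
# each one with the DeepSpeedGPTInference module built by
# DS_GPTNEOContainer.create_module():
#
#     import torch
#     import deepspeed
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
#     engine = deepspeed.init_inference(model,
#                                       dtype=torch.float16,
#                                       replace_with_kernel_inject=True)
#     model = engine.module  # run generation as usual on the injected model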