gptneox.py

'''Copyright The Microsoft DeepSpeed Team'''
from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.megatron import MegatronContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from packaging import version as pkg_version
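

# Inference container for GPT-NeoX transformer layers: combines meta-tensor
# checkpoint loading and Megatron-style tensor-parallel handling with the
# base transformer container.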
class DS_GPTNEOXContainer(MetaTensorContainer,
                          MegatronContainer,
                          BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
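
    # Build the fused DeepSpeedGPTInference module for this layer. For
    # Megatron-v2 style checkpoints, switch the rotary embedding config to the
    # rotate-half variant instead of rotate-every-two.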
    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention

        if self.megatron_v2:
            self.module.config.rotate_half = True
            self.module.config.rotate_every_two = False

        return self.module
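
    # Copy this layer's parameters from the source state dict `sd` into the
    # DeepSpeed module, applying quantization and tensor-parallel splitting
    # via `weight_quantizer` and `mp_replace`.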
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (
            'attention.query_key_value.weight', \
            'attention.query_key_value.bias', \
            'attention.dense.weight', \
            'attention.dense.bias', \
            'mlp.dense_h_to_4h.weight', \
            'mlp.dense_h_to_4h.bias', \
            'mlp.dense_4h_to_h.weight', \
            'mlp.dense_4h_to_h.bias', \
            'post_attention_layernorm.weight', \
            'post_attention_layernorm.bias', \
            'input_layernorm.weight', \
            'input_layernorm.bias'
        )
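
        # Index map into param_names / transformer_param_names:
        #   0-1   fused QKV weight/bias (needs head-aware splitting)
        #   2-3   attention output projection
        #   4-9   MLP projections and post-attention layernorm
        #   10-11 input layernorm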
        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv,
                       heads=self.policy.client_module.attention.num_attention_heads)
        for i in range(2, 4):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(4, 10):
            maybe_copy(module.mlp,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(10, 12):
            maybe_copy(module,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
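

# Injection policy describing how to read the weights, biases, and layernorm
# parameters out of a Hugging Face GPTNeoXLayer so they can be copied into
# the DeepSpeed inference container above.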
class GPTNEOXLayerPolicy(TransformerPolicy):
    _orig_layer_class = None
    version = 0
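    # `version` selects how the attention submodule is named on the client
    # layer: 0 -> `attention`, otherwise `self_attention`.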

    def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
        super().__init__(inference, megatron_v2=megatron_v2, split_qkv=split_qkv)
        self.client_module = client_module
        if GPTNEOXLayerPolicy._orig_layer_class is None:
            if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
                GPTNEOXLayerPolicy._orig_layer_class = None
            else:
                try:
                    from transformers import GPTNeoXLayer
                    GPTNEOXLayerPolicy._orig_layer_class = GPTNeoXLayer
                except ImportError:
                    GPTNEOXLayerPolicy._orig_layer_class = None
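
    # Hidden size (second dim of the fused QKV weight) and attention head count.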
    def get_hidden_heads(self):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
            attention = self.client_module.self_attention

        return attention.query_key_value.weight.shape[1], \
               attention.num_attention_heads
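
    # Fused QKV projection and attention output projection parameters.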
    def attention(self):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
            attention = self.client_module.self_attention

        return attention.query_key_value.weight, \
               attention.query_key_value.bias, \
               attention.dense.weight, \
               attention.dense.bias
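
    # Weights and biases of the two MLP projections (h -> 4h and 4h -> h).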
    def mlp(self):
        return self.client_module.mlp.dense_h_to_4h.weight, \
               self.client_module.mlp.dense_h_to_4h.bias, \
               self.client_module.mlp.dense_4h_to_h.weight, \
               self.client_module.mlp.dense_4h_to_h.bias
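
    # Post-attention and input layernorm parameters.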
    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias