'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.megatron import MegatronContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from packaging import version as pkg_version


class DS_GPTNEOXContainer(MetaTensorContainer, MegatronContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
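
        # GPT-NeoX (Megatron-style) checkpoints apply rotary embeddings to the first and
        # second half of each head dimension ("rotate_half") rather than to interleaved
        # pairs, so the inference kernel is switched to that convention here.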
        if self.megatron_v2:
            self.module.config.rotate_half = True
            self.module.config.rotate_every_two = False

        return self.module

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
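        # Copy the HuggingFace GPT-NeoX checkpoint parameters listed below (in order)
        # into the corresponding transformer_param_names slots of the DeepSpeed module.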
        param_names = (
            'attention.query_key_value.weight', \
            'attention.query_key_value.bias', \
            'attention.dense.weight', \
            'attention.dense.bias', \
            'mlp.dense_h_to_4h.weight', \
            'mlp.dense_h_to_4h.bias', \
            'mlp.dense_4h_to_h.weight', \
            'mlp.dense_4h_to_h.bias', \
            'post_attention_layernorm.weight', \
            'post_attention_layernorm.bias', \
            'input_layernorm.weight', \
            'input_layernorm.bias'
        )
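
        # Fused query/key/value weight and bias: Megatron-style checkpoints interleave
        # Q, K and V per attention head, so maybe_copy is given the head count and the
        # megatron_v2 / split_qkv flags to reorder (or split) the fused tensor.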
        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv,
                       heads=self.policy.client_module.attention.num_attention_heads)
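        # Attention output projection (attention.dense) weight and bias.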
        for i in range(2, 4):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
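        # MLP weights and biases plus the post-attention layernorm parameters
        # (param_names[4:10]), all of which live on module.mlp in the inference module.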
        for i in range(4, 10):
            maybe_copy(module.mlp,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])
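        # Input layernorm weight and bias are copied onto the module itself.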
        for i in range(10, 12):
            maybe_copy(module,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i])


class GPTNEOXLayerPolicy(TransformerPolicy):
    _orig_layer_class = None
    version = 0

    def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
        super().__init__(inference, megatron_v2=megatron_v2, split_qkv=split_qkv)
        self.client_module = client_module
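
        # Resolve the original HuggingFace layer class lazily; leave it as None when the
        # installed torch is too old or transformers does not provide GPTNeoXLayer.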
        if GPTNEOXLayerPolicy._orig_layer_class is None:
            if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
                GPTNEOXLayerPolicy._orig_layer_class = None
            else:
                try:
                    from transformers import GPTNeoXLayer
                    GPTNEOXLayerPolicy._orig_layer_class = GPTNeoXLayer
                except ImportError:
                    GPTNEOXLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
            attention = self.client_module.self_attention

        # query_key_value is Linear(hidden_size, 3 * hidden_size), so the hidden size is
        # the second dimension of its weight.
        return attention.query_key_value.weight.shape[1], \
               attention.num_attention_heads

    def attention(self):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
            attention = self.client_module.self_attention

        return attention.query_key_value.weight, \
               attention.query_key_value.bias, \
               attention.dense.weight, \
               attention.dense.bias

    def mlp(self):
        return self.client_module.mlp.dense_h_to_4h.weight, \
               self.client_module.mlp.dense_h_to_4h.bias, \
               self.client_module.mlp.dense_4h_to_h.weight, \
               self.client_module.mlp.dense_4h_to_h.bias

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
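
# Typical use (sketch, not part of this module): this policy/container pair is picked up
# by DeepSpeed's kernel injection when a HuggingFace GPT-NeoX model is passed to
# deepspeed.init_inference(..., replace_with_kernel_inject=True); the code above only
# defines how the GPTNeoXLayer parameters map onto the DeepSpeed inference module.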