'''Copyright The Microsoft DeepSpeed Team'''

from .base import *
from .features.megatron import MegatronContainer
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
import torch
from ..policy import TransformerPolicy
from packaging import version as pkg_version


class DS_MegatronGPTContainer(MegatronContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        # Fall back to the container's DeepSpeed model config when no explicit
        # config is passed in.
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedMegatronGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention

        # Models from the newer (v2) Megatron-LM code base expect the
        # rotate-half convention instead of rotate-every-two.
        if self.megatron_v2:
            self.module.config.rotate_half = True
            self.module.config.rotate_every_two = False

        return self.module
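
# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): this container is not
# normally constructed by hand. It is selected through DeepSpeed's policy
# mapping when kernel injection is requested for a Megatron-style GPT model,
# roughly as follows (`model` is a placeholder for a loaded Megatron GPT model):
#
#   import torch
#   import deepspeed
#
#   engine = deepspeed.init_inference(model,
#                                     mp_size=2,
#                                     dtype=torch.half,
#                                     replace_with_kernel_inject=True)
#
# Internally, matching transformer layers are swapped for the
# DeepSpeedMegatronGPTInference module built by create_module() above.
# ------------------------------------------------------------------------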


# TODO: Megatron GPT MoE inherits from Megatron policy and replaces mlp
# TODO: Generalize MoE overall goal, expand beyond Megatron
class MegatronLayerPolicy(TransformerPolicy):
    _orig_layer_class = None
    version = 0
    moe_type = 'standard'
    megatron_v2 = True
    use_mup = False

    def __init__(self, client_module, inference=True):
        super().__init__(inference,
                         megatron_v2=MegatronLayerPolicy.megatron_v2,
                         use_mup=MegatronLayerPolicy.use_mup)
        self.client_module = client_module
        # We use the Megatron version to differentiate between the old and new
        # Megatron-LM source code.
        if MegatronLayerPolicy._orig_layer_class is None:
            if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
                MegatronLayerPolicy._orig_layer_class = None
            else:
                try:
                    from megatron.model.transformer import ParallelTransformerLayer
                    MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
                except ImportError:
                    MegatronLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.shape[1], \
               self.client_module.attention.num_attention_heads

    def attention(self):
        if self.inference:
            if MegatronLayerPolicy.version == 0:
                attention = self.client_module.attention
            else:
                attention = self.client_module.self_attention

        # Return the fused QKV and output-projection parameters of the attention block.
        return attention.query_key_value.weight, \
               attention.query_key_value.bias, \
               attention.dense.weight, \
               attention.dense.bias
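
    # Illustrative note (assumption, not from the original file): for layers from
    # the newer Megatron-LM code base, which expose `self_attention` rather than
    # `attention`, the class-level version is expected to be bumped before
    # injection, e.g.:
    #
    #   MegatronLayerPolicy.version = 1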

    def mlp(self, moe_type='standard'):
        from deepspeed.moe.utils import has_moe_layers
        moe, _ = has_moe_layers(self.client_module)

        if moe:
            # For MoE layers, return the per-expert MLP parameters; the
            # non-standard variant additionally returns the shared MLP and the
            # gating coefficient weights.
            moe_experts = self.client_module.mlp.deepspeed_moe.experts.deepspeed_experts if moe_type == 'standard' else \
                          self.client_module.mlp.moe.deepspeed_moe.experts.deepspeed_experts
            num_experts = len(moe_experts)
            if moe_type == 'standard':
                return [moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)]
            else:
                return [moe_experts[i].dense_h_to_4h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_h_to_4h.bias for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.weight for i in range(num_experts)], \
                       [moe_experts[i].dense_4h_to_h.bias for i in range(num_experts)], \
                       self.client_module.mlp.mlp.dense_h_to_4h.weight, \
                       self.client_module.mlp.mlp.dense_h_to_4h.bias, \
                       self.client_module.mlp.mlp.dense_4h_to_h.weight, \
                       self.client_module.mlp.mlp.dense_4h_to_h.bias, \
                       self.client_module.mlp.coefficient.weight
        else:
            return self.client_module.mlp.dense_h_to_4h.weight, \
                   self.client_module.mlp.dense_h_to_4h.bias, \
                   self.client_module.mlp.dense_4h_to_h.weight, \
                   self.client_module.mlp.dense_4h_to_h.bias
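
    # Minimal usage sketch (illustrative only; `policy` is a placeholder for a
    # constructed MegatronLayerPolicy). For a dense, non-MoE layer the call
    # unpacks into the four feed-forward parameters:
    #
    #   mlp_h_to_4h_w, mlp_h_to_4h_b, mlp_4h_to_h_w, mlp_4h_to_h_b = policy.mlp()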

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
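
# ------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only; `layer` is a placeholder for a
# Megatron-LM ParallelTransformerLayer instance). In normal use these accessors
# are called by DeepSpeed's injection machinery rather than by user code:
#
#   policy = MegatronLayerPolicy(layer, inference=True)
#   hidden_size, num_heads = policy.get_hidden_heads()
#   qkvw, qkvb, dense_w, dense_b = policy.attention()
#   attn_nw, attn_nb, input_nw, input_nb = policy.layernorm()
# ------------------------------------------------------------------------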