# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.hybrid_engine import HybridEngineContainer
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_get_lora
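
# Registry of Hugging Face model classes this container supports; populated
# with `BloomModel` in BLOOMLayerPolicy.__init__ below once `transformers`
# is importable.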
supported_models = {None}


class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model-specific attributes should be defined here rather than in
        # the base class.
        self.bigscience_bloom = True  # select the BLOOM-specific (ALiBi) kernel path
        self.triangular_masking = False  # BLOOM supplies its own attention mask

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        self.module.config.invert_mask = False
        return self.module
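
    # Illustrative sketch only (the keyword names are assumptions; see
    # deepspeed/module_inject/replace_module.py for the actual call site).
    # The injection machinery roughly does:
    #
    #   policy = BLOOMLayerPolicy(hf_bloom_block, inference=True)
    #   container = DS_BloomContainer(policy=policy, config=model_config,
    #                                 model_config=ds_config, layer_id=0,
    #                                 child=hf_bloom_block)
    #   module = container.create_module()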

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        # BLOOM fuses Q/K/V into one projection with the heads interleaved, so
        # a plain contiguous copy shards correctly across model-parallel ranks.
        self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

    def get_lora_matched_pair(self):
        """
        Required by `HybridEngineContainer`: pair each LoRA parameter with the
        inference-side weight it updates.
        """
        fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)]
        return ret

    def set_lora_params(self):
        """
        Required by `HybridEngineContainer`: collect the client module's linear
        layers in fc1, fc2, qkv, out order (matching get_lora_matched_pair).
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.mlp.dense_h_to_4h,
                self.policy.client_module.mlp.dense_4h_to_h,
                self.policy.client_module.self_attention.query_key_value,
                self.policy.client_module.self_attention.dense,
            ]
        ]

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Hugging Face BLOOM state_dict keys for a single transformer block,
        # in the order expected by transformer_param_names.
        param_names = (
            'self_attention.query_key_value.weight',
            'self_attention.query_key_value.bias',
            'self_attention.dense.weight',
            'self_attention.dense.bias',
            'mlp.dense_h_to_4h.weight',
            'mlp.dense_h_to_4h.bias',
            'mlp.dense_4h_to_h.weight',
            'mlp.dense_4h_to_h.bias',
            'post_attention_layernorm.weight',
            'post_attention_layernorm.bias',
            'input_layernorm.weight',
            'input_layernorm.bias',
        )
        # Fused QKV weight and bias (indices 0-1) need head-aware copying.
        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv)
        # Attention output projection (indices 2-3).
        for i in range(2, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        # MLP parameters and the post-attention layernorm (indices 4-9), which
        # lives on the MLP module in the DeepSpeed layout.
        for i in range(4, 10):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        # Input layernorm (indices 10-11) sits on the top-level module.
        for i in range(10, 12):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
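
    # For a checkpoint shard `sd`, the keys consumed above therefore look like
    # f"{prefix}self_attention.query_key_value.weight" and so on, matching the
    # `bigscience/bloom*` checkpoints on the Hugging Face Hub; the per-layer
    # `prefix` (e.g. "h.0.") is supplied by the caller.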


class BLOOMLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True, split_qkv=False):
        super().__init__(inference, linear_layer=True, use_load_prefix=use_load_prefix, split_qkv=split_qkv)
        self.client_module = client_module
        try:
            import transformers
            BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
            global supported_models
            supported_models.update({transformers.models.bloom.modeling_bloom.BloomModel})
        except Exception as e:
            print(f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}")
            BLOOMLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.self_attention.hidden_size, \
                self.client_module.self_attention.num_heads, \
                self.client_module.input_layernorm.eps, \
                DEFAULT_INTERMEDIATE_SIZE
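
    # DEFAULT_INTERMEDIATE_SIZE comes from `.base` via the wildcard import; as
    # I read it, it is a sentinel meaning BLOOM uses the conventional
    # 4 * hidden_size MLP width rather than a model-specific intermediate size.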

    def attention(self, enable_training=False):
        return self.client_module.self_attention.query_key_value.weight, \
                self.client_module.self_attention.query_key_value.bias, \
                self.client_module.self_attention.dense.weight, \
                self.client_module.self_attention.dense.bias

    def mlp(self, enable_training=False):
        return self.client_module.mlp.dense_h_to_4h.weight, \
                self.client_module.mlp.dense_h_to_4h.bias, \
                self.client_module.mlp.dense_4h_to_h.weight, \
                self.client_module.mlp.dense_4h_to_h.bias

    def layernorm(self):
        return self.client_module.post_attention_layernorm.weight, \
                self.client_module.post_attention_layernorm.bias, \
                self.client_module.input_layernorm.weight, \
                self.client_module.input_layernorm.bias
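

# Minimal end-to-end usage sketch (hedged: exact keyword arguments vary across
# DeepSpeed versions; `deepspeed.init_inference` and `replace_with_kernel_inject`
# are part of the public API, the rest is illustrative):
#
#   import torch
#   import deepspeed
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
#   engine = deepspeed.init_inference(model,
#                                     dtype=torch.float16,
#                                     replace_with_kernel_inject=True)
#
# Kernel injection walks the model, matches BloomBlock layers via
# BLOOMLayerPolicy._orig_layer_class, and swaps each one for the
# DeepSpeedBloomInference module built by DS_BloomContainer.create_module().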