# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features import MetaTensorContainer, HybridSplitQKVContainer
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
from deepspeed.utils.types import ActivationFuncType


class DS_OPTContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedOPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`
        """
        self.lora_params = [
            maybe_get_lora(p) for p in [
                self.policy.client_module.fc1,
                self.policy.client_module.fc2,
                self.policy.client_module.self_attn.q_proj,
                self.policy.client_module.self_attn.k_proj,
                self.policy.client_module.self_attn.v_proj,
                self.policy.client_module.self_attn.out_proj,
            ]
        ]

    def set_q_k_v(self):
        """
        Necessary to implement for `HybridSplitQKVContainer`
        """
        self.qw = self.policy.client_module.self_attn.q_proj.weight
        self.qb = self.policy.client_module.self_attn.q_proj.bias
        self.kw = self.policy.client_module.self_attn.k_proj.weight
        self.kb = self.policy.client_module.self_attn.k_proj.bias
        self.vw = self.policy.client_module.self_attn.v_proj.weight
        self.vb = self.policy.client_module.self_attn.v_proj.bias

    def get_lora_matched_pair(self):
        fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
               (k_lora, self.kw), (v_lora, self.vw)]
        return ret

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Parameter names of the HF OPT decoder layer, in the order consumed by the loops below.
        param_names = (
            'self_attn.q_proj.weight', \
            'self_attn.k_proj.weight', \
            'self_attn.v_proj.weight', \
            'self_attn.q_proj.bias', \
            'self_attn.k_proj.bias', \
            'self_attn.v_proj.bias', \
            'self_attn.out_proj.weight', \
            'self_attn.out_proj.bias', \
            'fc1.weight', \
            'fc1.bias', \
            'fc2.weight', \
            'fc2.bias', \
            'final_layer_norm.weight', \
            'final_layer_norm.bias', \
            'self_attn_layer_norm.weight', \
            'self_attn_layer_norm.bias'
        )

        # Indices 0-5: fuse the separate q/k/v weights (i == 0) and biases (i == 3)
        # into the combined QKV parameters of the inference module.
        for i in range(0, 6, 3):
            maybe_copy_qkv(module.attention,
                           sd,
                           weight_quantizer,
                           mp_replace,
                           transformer_param_names[i // 3],
                           [prefix + param_names[i], prefix + param_names[i + 1], prefix + param_names[i + 2]],
                           split_qkv=self.policy.split_qkv)
        # The offset of 4 below accounts for the six q/k/v entries above collapsing
        # into two fused QKV parameters.
        # Indices 6-7: attention output projection.
        for i in range(6, 8):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        # Indices 8-13: fc1/fc2 and the final layer norm, copied onto the MLP block.
        for i in range(8, 14):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        # Indices 14-15: the self-attention layer norm, copied onto the top-level module.
        for i in range(14, 16):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])


class HFOPTLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True):
        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
        self.client_module = client_module

        # Record the original HF layer class; fall back to None if the installed
        # transformers version does not provide OPT.
        try:
            import transformers
            HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
        except:
            HFOPTLayerPolicy._orig_layer_class = None

        # Map the HF config's activation_function onto DeepSpeed's ActivationFuncType;
        # default to ReLU when no HF model config has been captured.
        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
                                                                     "activation_function"):
            if TransformerPolicy.hf_model_config.activation_function == "relu":
                self.mlp_act_func_type = ActivationFuncType.ReLU
            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
                self.mlp_act_func_type = ActivationFuncType.GELU
            else:
                raise ValueError("Unsupported activation function: {}".format(
                    TransformerPolicy.hf_model_config.activation_function))
        else:
            self.mlp_act_func_type = ActivationFuncType.ReLU  # default

    def get_hidden_heads(self):
        return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads, \
                self.client_module.self_attn_layer_norm.eps, \
                DEFAULT_INTERMEDIATE_SIZE

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias

        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias

        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        # Concatenate q/k/v along the output dimension to form the fused QKV weight and bias.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias

    def mlp(self, enable_training=False):
        return self.client_module.fc1.weight, \
               self.client_module.fc1.bias, \
               self.client_module.fc2.weight, \
               self.client_module.fc2.bias

    def layernorm(self):
        return self.client_module.final_layer_norm.weight, \
               self.client_module.final_layer_norm.bias, \
               self.client_module.self_attn_layer_norm.weight, \
               self.client_module.self_attn_layer_norm.bias
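

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the container/policy API):
# the container and policy above are picked up automatically by DeepSpeed's
# kernel-injection path. This is a minimal example, assuming a Hugging Face
# OPT checkpoint and the public deepspeed.init_inference() entry point; the
# exact keyword arguments may differ across DeepSpeed releases.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "facebook/opt-1.3b"  # any HF OPT checkpoint should exercise HFOPTLayerPolicy
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    # replace_with_kernel_inject=True swaps each OPTDecoderLayer for the
    # DeepSpeedOPTInference module created by DS_OPTContainer.create_module().
    engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)

    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(engine.module.device)
    outputs = engine.module.generate(inputs.input_ids, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))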