# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
import math
import torch.nn as nn
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp


class TritonMLP(nn.Module):

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
        super(TritonMLP, self).__init__()

        self.config = config
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        device = get_accelerator().current_device_name()
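        # gamma/beta of the layer norm applied to the attention output ahead of the MLP;
        # all parameters here are inference-only, so gradients are disabled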
        self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
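        # the intermediate (FFN) dimension is split across tensor-parallel ranks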
        intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
        self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                intm_size_per_partition,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)
        self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
                                                 self.config.hidden_size,
                                                 dtype=data_type,
                                                 device=device),
                                     requires_grad=False)
        self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                     requires_grad=False)

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
        self.merge_count = int(math.log2(merge_count))
        self.mp_group = mp_group
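
        # op bindings that dispatch to the fused inference MLP kernels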
        self.mlp_gemm_func = MLPGemmOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
        self.fused_gemm_gelu = GELUGemmOp(config)
        self.residual_add_func = ResidualAddOp(config)
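
    # forward wiring (as used by DeepSpeed's inference transformer block): `input` is the
    # attention output, `residual` the block's residual stream, `residual_norm` the already
    # layer-normed hidden states, and `bias` the attention output-projection bias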
    def forward(self, input, residual, residual_norm, bias):
        residual_add = None
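        # with no post-attention layer norm (attn_nw is None), feed the pre-normalized
        # activations through a fused GEMM + GeLU + GEMM; otherwise mlp_gemm_func also
        # fuses the layer norm and returns an intermediate residual term that
        # residual_add_func consumes below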
        if self.attn_nw is None:
            output = self.fused_gemm_gelu(input=residual_norm,
                                          weight=self.inter_w,
                                          bias=self.inter_b,
                                          weight_out=self.output_w)
        else:
            output, residual_add = self.mlp_gemm_func(input=input,
                                                      residual=residual,
                                                      input_bias=bias,
                                                      weight_interm=self.inter_w,
                                                      weight_out=self.output_w,
                                                      bias=self.inter_b,
                                                      gamma=self.attn_nw,
                                                      beta=self.attn_nb)
        residual = self.residual_add_func(hidden_state=output,
                                          residual=residual,
                                          attention_output=input,
                                          attention_bias=bias if bias is not None else self.output_b,
                                          final_bias=self.output_b,
                                          add_bias=bias is not None,
                                          residual_add=residual_add)
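        # inter_w/output_w are partitioned, so each rank holds only a partial result of the
        # output GEMM; sum the partials across the tensor-parallel group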
        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(residual, group=self.mp_group)

        return residual
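
# Minimal usage sketch (illustrative only): TritonMLP is normally built and populated by
# DeepSpeed's inference engine; the import path and config fields below are assumptions
# and may differ between DeepSpeed releases.
#
#   from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
#
#   config = DeepSpeedInferenceConfig(hidden_size=1024,
#                                     intermediate_size=4096,
#                                     heads=16,
#                                     dtype=torch.float16)
#   mlp = TritonMLP(config)  # weights are loaded/overwritten by the inference engine
#   output = mlp(attention_output, residual, residual_norm, attention_bias)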