# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext
from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual

# Lazily-loaded handle to the built CUDA inference module, used as a fallback
# when the Triton layer-norm kernels are disabled (use_triton_ln=False).
inference_module = None


def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode):
    assert not transposed_mode and not async_op and not q_int8
    return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True)


def fused_gemm_gelu(input,
                    weight,
                    weight_scale,
                    bias,
                    weight_out,
                    weight_out_scale,
                    epsilon,
                    pre_layer_norm,
                    q_int8,
                    async_op,
                    transposed_mode,
                    use_triton_ln=True):
    assert not transposed_mode

    # activation
    activation = "gelu"

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True)

    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias is added later, in the residual_add + bias + layerNorm layer
        use_triton=True)

    return ff_out


def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False):
    assert not transposed_mode and not do_flash_attn
    qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)

    return qkv_out


def mlp_gemm_func(input,
                  residual,
                  input_bias,
                  weight_interm,
                  weight_out,
                  bias,
                  gamma,
                  beta,
                  epsilon,
                  pre_layer_norm,
                  mlp_after_attn,
                  weight_interm_scale,
                  weight_out_scale,
                  q_int8,
                  mlp_act_func_type,
                  transposed_mode,
                  use_triton_ln=True):
    assert not transposed_mode

    # residual add and layerNorm after attention
    if use_triton_ln:
        mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
    else:
        global inference_module
        if inference_module is None:
            inference_module = InferenceBuilder().load()
        mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)

    # activation
    if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU:
        activation = "gelu"
    elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU:
        activation = "relu"
    else:
        activation = ""

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True)

    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias is added later, in the residual_add + bias + layerNorm layer
        use_triton=True)

    return ff_out, mlp_input


def qkv_gemm_func(
    input,
    weight,
    q_scale,
    bias,
    gamma,
    beta,
    epsilon,
    add_bias,
    q_int8,
    transposed_mode=False,
    use_triton_ln=True,
):
    assert not transposed_mode

    # layerNorm on the block input before the QKV projection
    if use_triton_ln:
        qkv_input = layer_norm(input, gamma, beta, epsilon)
    else:
        global inference_module
        if inference_module is None:
            inference_module = InferenceBuilder().load()
        qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon)

    qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)

    return qkv_out, qkv_input
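

# A minimal smoke-test sketch (not part of the DeepSpeed API): it assumes a CUDA
# device with a working Triton install and fp16 tensors, since matmul_ext.matmul
# and the Triton layer_norm kernel run on GPU. The shapes, weight layout
# (in_features, out_features), and epsilon below are illustrative assumptions only.
if __name__ == "__main__":
    import torch

    if torch.cuda.is_available():
        batch, seq, hidden = 1, 16, 256
        x = torch.randn(batch, seq, hidden, dtype=torch.float16, device="cuda")
        w_qkv = torch.randn(hidden, 3 * hidden, dtype=torch.float16, device="cuda")
        gamma = torch.ones(hidden, dtype=torch.float16, device="cuda")
        beta = torch.zeros(hidden, dtype=torch.float16, device="cuda")

        # layer norm of the input followed by the QKV projection;
        # add_bias=False, so the bias argument (None here) is ignored
        qkv_out, qkv_input = qkv_gemm_func(x, w_qkv, None, None, gamma, beta, 1e-5, add_bias=False, q_int8=False)
        print("qkv_out:", tuple(qkv_out.shape), "qkv_input:", tuple(qkv_input.shape))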