# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Optional

import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class LinearOp(BaseOp):

    def __init__(self, config: DeepSpeedInferenceConfig):
        super(LinearOp, self).__init__(config)
        try:
            if self.config.dtype in [torch.float16, torch.int8]:
                if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
                    from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
                    self.linear_func = _triton_linear_func
                    # Autotune only on the first layer; the tuned matmul table
                    # is shared by all subsequent layers.
                    triton_autotune = config.triton_autotune and config.layer_id == 0
                    if triton_autotune:
                        __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
                else:
                    self.linear_func = self.inference_module.linear_layer_fp16
            elif self.config.dtype == torch.bfloat16:
                self.linear_func = self.inference_module.linear_layer_bf16
            else:
                self.linear_func = self.inference_module.linear_layer_fp32
        except AttributeError:
            # The compiled inference module (or the requested kernel) is
            # unavailable; defer to the Python fallback so the failure
            # surfaces at call time.
            self.linear_func = self.linear_fallback

    def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose):
        # Pure-Python path is intentionally unimplemented: this op requires
        # the compiled DeepSpeed inference kernels.
        raise NotImplementedError

    def forward(self,
                input: torch.Tensor,
                weight: torch.Tensor,
                bias: torch.Tensor,
                add_bias: bool,
                do_flash_attn: bool,
                num_heads: int,
                external_cache: Optional[bool] = None,  # unused here
                num_layers: Optional[int] = None):  # unused here
        qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
                                   self.config.transposed_mode)
        return qkv_out
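
    # Warm Triton's matmul autotune cache across the sequence lengths this
    # layer can see, so the first real forward pass does not pay tuning
    # latency. Runs one (seqlen x hidden) @ (hidden x 3*hidden) GEMM per
    # sampled length, matching the fused QKV projection shape.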
    @staticmethod
    def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
        from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
        # Sample sequence lengths from min_seqlen up past max_seqlen in steps
        # of the kernel cache stride.
        seqlen = [(min_seqlen + i)
                  for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
        Fp16Matmul._read_autotune_table()
        for N in seqlen:
            A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
            B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
            matmul(A, B)
        Fp16Matmul._update_autotune_table()
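

# --------------------------------------------------------------------------
# Illustrative reference (not part of the DeepSpeed API): a minimal
# pure-PyTorch sketch of the matmul the fused linear kernels compute, handy
# for sanity-checking their output. Assumptions: in non-transposed mode the
# weight is laid out [in_features, out_features] (transposed_mode flips it),
# and the extra flash-attention packing the fused kernels can perform when
# do_flash_attn is set is ignored here.
# --------------------------------------------------------------------------
def _reference_linear(input: torch.Tensor,
                      weight: torch.Tensor,
                      bias: Optional[torch.Tensor],
                      add_bias: bool,
                      transpose: bool) -> torch.Tensor:
    w = weight.t() if transpose else weight
    out = torch.matmul(input, w)
    if add_bias and bias is not None:
        out = out + bias
    return out


if __name__ == "__main__":
    # Smoke test of the reference sketch on random data (CPU is fine).
    x = torch.randn(4, 16, 64)   # [batch, seq, hidden]
    w = torch.randn(64, 3 * 64)  # [hidden, 3 * hidden] -- fused QKV layout
    b = torch.randn(3 * 64)
    print(_reference_linear(x, w, b, add_bias=True, transpose=False).shape)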