linear.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed
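

# LinearOp binds the fused linear (QKV projection) kernel that matches the
# configured inference dtype, preferring the Triton fp16 path when it is
# available, and falling back to a Python stub if the compiled inference
# module does not expose the expected kernels.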
class LinearOp(BaseOp):

    def __init__(self, config: DeepSpeedInferenceConfig):
        super(LinearOp, self).__init__(config)
        try:
            if self.config.dtype in [torch.float16, torch.int8]:
                # Use the Triton linear kernel for fp16 when Triton is available and enabled.
                if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
                    from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
                    self.linear_func = _triton_linear_func
                    # Autotune only once, on the first transformer layer.
                    triton_autotune = config.triton_autotune and config.layer_id == 0
                    if triton_autotune:
                        __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
                else:
                    self.linear_func = self.inference_module.linear_layer_fp16
            elif self.config.dtype == torch.bfloat16:
                self.linear_func = self.inference_module.linear_layer_bf16
            else:
                self.linear_func = self.inference_module.linear_layer_fp32
        except AttributeError:
            # The compiled inference module does not expose the expected kernel.
            self.linear_func = self.linear_fallback

    def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose):
        raise NotImplementedError
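
    # A pure-PyTorch fallback is not implemented here.  A rough sketch (an
    # assumption about the kernel semantics, not the upstream implementation)
    # for the default, non-transposed weight layout would be:
    #
    #   output = torch.matmul(input, weight.t() if transpose else weight)
    #   if add_bias:
    #       output += bias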

    def forward(self,
                input: torch.Tensor,
                weight: torch.Tensor,
                bias: torch.Tensor,
                add_bias: bool,
                do_flash_attn: bool,
                num_heads: int,
                external_cache: bool = None,
                num_layers: int = None):
        qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
                                   self.config.transposed_mode)
        return qkv_out
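
    # Warm up the Triton fp16 matmul autotune cache for the range of sequence
    # lengths this layer is expected to see, so the first real forward pass
    # does not pay the tuning cost.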
    @staticmethod
    def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
        from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
        seqlen = [(min_seqlen + i)
                  for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
        Fp16Matmul._read_autotune_table()
        for N in seqlen:
            A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
            B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
            matmul(A, B)
        Fp16Matmul._update_autotune_table()
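

# Example (illustrative only, not part of the module): a minimal sketch of how
# LinearOp might be driven, assuming `cfg` is a populated DeepSpeedInferenceConfig
# and the inference op was built for the chosen dtype.
#
#   cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, dtype=torch.float16)
#   op = LinearOp(cfg)
#   qkv_out = op(hidden_states, qkv_weight, qkv_bias,
#                add_bias=True, do_flash_attn=False, num_heads=cfg.heads)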