mlp.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp


class TritonMLP(nn.Module):

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
        super(TritonMLP, self).__init__()

        self.config = config
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        device = get_accelerator().current_device_name()

        # Post-attention LayerNorm parameters (gamma/beta) applied before the MLP GEMMs.
        self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)

        # The intermediate dimension is sliced across the tensor-parallel (mp) group.
        intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
        self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                intm_size_per_partition,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)
        self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
                                                 self.config.hidden_size,
                                                 dtype=data_type,
                                                 device=device),
                                     requires_grad=False)
        self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                     requires_grad=False)

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
        self.merge_count = int(math.log2(merge_count))
        self.mp_group = mp_group

        # Fused inference ops bound from the op_binding layer
        # (vector_matmul_func is initialized here but not used in forward()).
        self.mlp_gemm_func = MLPGemmOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
        self.fused_gemm_gelu = GELUGemmOp(config)
        self.residual_add_func = ResidualAddOp(config)

    def forward(self, input, residual, residual_norm, bias):
        residual_add = None
        if self.attn_nw is None:
            # No post-attention LayerNorm parameters: run the fused GEMM + GeLU + GEMM
            # path directly on the already-normalized hidden states.
            output = self.fused_gemm_gelu(input=residual_norm,
                                          weight=self.inter_w,
                                          bias=self.inter_b,
                                          weight_out=self.output_w)
        else:
            # Fused LayerNorm + GEMM + activation + GEMM; also returns the updated
            # residual, which the residual-add op below folds back in.
            output, residual_add = self.mlp_gemm_func(input=input,
                                                      residual=residual,
                                                      input_bias=bias,
                                                      weight_interm=self.inter_w,
                                                      weight_out=self.output_w,
                                                      bias=self.inter_b,
                                                      gamma=self.attn_nw,
                                                      beta=self.attn_nb)
        residual = self.residual_add_func(hidden_state=output,
                                          residual=residual,
                                          attention_output=input,
                                          attention_bias=bias if bias is not None else self.output_b,
                                          final_bias=self.output_b,
                                          add_bias=bias is not None,
                                          residual_add=residual_add)

        # Each tensor-parallel rank holds only a slice of the intermediate dimension,
        # so the partial MLP outputs are summed across the group.
        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(residual, group=self.mp_group)

        return residual
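

# --- Illustrative shape sketch (not part of the original DeepSpeed module) ---
# A minimal sketch of the parameter and activation shapes TritonMLP works with.
# The concrete sizes below (hidden_size=1024, intermediate_size=4096, mp_size=2)
# are hypothetical; only the partitioning rule (intermediate_size // mp_size) and
# the parameter layouts are taken from __init__ above.
def _tritonmlp_shape_sketch(hidden_size=1024, intermediate_size=4096, mp_size=2):
    intm_per_partition = intermediate_size // mp_size  # slice held by each tensor-parallel rank
    return {
        "input / residual / residual_norm": ("batch", "seq_len", hidden_size),
        "inter_w": (hidden_size, intm_per_partition),   # first GEMM weight (column-parallel)
        "inter_b": (intm_per_partition,),
        "output_w": (intm_per_partition, hidden_size),  # second GEMM weight (row-parallel)
        "output_b": (hidden_size,),
        "forward() output": ("batch", "seq_len", hidden_size),
    }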