ops.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext
from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual

inference_module = None


def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode):
    # Bias-free matmul routed through the Triton kernel; async, int8, and
    # transposed-weight modes are not supported on this path.
    assert not transposed_mode and not async_op and not q_int8
    return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True)


def fused_gemm_gelu(input,
                    weight,
                    weight_scale,
                    bias,
                    weight_out,
                    weight_out_scale,
                    epsilon,
                    pre_layer_norm,
                    q_int8,
                    async_op,
                    transposed_mode,
                    use_triton_ln=True):
    assert not transposed_mode

    # activation
    activation = "gelu"

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True)

    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias added layer with residual_add + bias + layerNorm layer
        use_triton=True)

    return ff_out


def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False):
    # Single linear projection through the Triton matmul; flash-attention
    # packing and transposed weights are not supported on this path.
    assert not transposed_mode and not do_flash_attn
    qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)
    return qkv_out


def mlp_gemm_func(input,
                  residual,
                  input_bias,
                  weight_interm,
                  weight_out,
                  bias,
                  gamma,
                  beta,
                  epsilon,
                  pre_layer_norm,
                  mlp_after_attn,
                  weight_interm_scale,
                  weight_out_scale,
                  q_int8,
                  mlp_act_func_type,
                  transposed_mode,
                  use_triton_ln=True):
    assert not transposed_mode

    # residual add and layerNorm after attention
    if use_triton_ln:
        mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
    else:
        global inference_module
        if inference_module is None:
            inference_module = InferenceBuilder().load()
        mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)

    # activation
    if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU:
        activation = "gelu"
    elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU:
        activation = "relu"
    else:
        activation = ""

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True)

    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias added layer with residual_add + bias + layerNorm layer
        use_triton=True)

    return ff_out, mlp_input


def qkv_gemm_func(
        input,
        weight,
        q_scale,
        bias,
        gamma,
        beta,
        epsilon,
        add_bias,
        q_int8,
        transposed_mode=False,
        use_triton_ln=True,
):
    assert not transposed_mode

    # layerNorm before the QKV projection
    if use_triton_ln:
        qkv_input = layer_norm(input, gamma, beta, epsilon)
    else:
        global inference_module
        if inference_module is None:
            inference_module = InferenceBuilder().load()
        qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon)

    qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)

    return qkv_out, qkv_input
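

# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the original
# module). It assumes a CUDA device with Triton available, fp16 tensors, and
# weights stored as [in_features, out_features] so matmul_ext.matmul can
# consume them directly; the shapes and head count below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    batch, seq, hidden = 1, 128, 1024
    x = torch.randn(batch, seq, hidden, dtype=torch.float16, device="cuda")
    w_qkv = torch.randn(hidden, 3 * hidden, dtype=torch.float16, device="cuda")
    b_qkv = torch.randn(3 * hidden, dtype=torch.float16, device="cuda")

    # QKV-style projection through the Triton matmul path, with bias applied.
    qkv = linear_func(x, w_qkv, b_qkv, add_bias=True, do_flash_attn=False, num_heads=16)
    print(qkv.shape)  # expected: torch.Size([1, 128, 3072])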