DS4Sci_EvoformerAttention_bench.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
This script benchmarks the performance of the DS4Sci_EvoformerAttention op against a
PyTorch reference implementation.
To run the script:
1. Clone the CUTLASS repo, e.g. git clone https://github.com/NVIDIA/cutlass.git
2. Set the CUTLASS_PATH environment variable, e.g. export CUTLASS_PATH=$(pwd)/cutlass
   (a guard after the imports below checks that it is set)
3. Run the script, e.g. python DS4Sci_EvoformerAttention_bench.py
"""
import contextlib
import torch
from typing import List
from torch.nn import functional as F
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator
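
# Optional guard (a small sketch added here): the DS4Sci_EvoformerAttention op is built against
# CUTLASS via the CUTLASS_PATH environment variable (see the docstring above), so fail early
# with a clear message if it is unset. The exact error text is a choice made for this guard.
import os

if "CUTLASS_PATH" not in os.environ:
    raise RuntimeError("CUTLASS_PATH is not set. E.g. export CUTLASS_PATH=$(pwd)/cutlass")
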

def attention_reference(
        q_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        k_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        v_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        biases: List[torch.Tensor],
        sm_scale: float) -> torch.Tensor:
    # Original shape: [*, Dim_Q, H, C_hid] -> Transpose to: [*, H, Dim_Q, C_hid]
    q = q_input.transpose(-2, -3)
    k = k_input.transpose(-2, -3)
    v = v_input.transpose(-2, -3)
    # Now, q, k, v are in shape: [*, H, Dim_Q, C_hid]
    # Transpose k to shape [*, H, C_hid, Dim_Q]
    k_t = k.transpose(-1, -2)
    # q and k_t are in shapes [*, H, Dim_Q, C_hid] and [*, H, C_hid, Dim_Q] respectively,
    # so the matmul gives the attention logits of shape [*, H, Dim_Q, Dim_Q]
    a = torch.matmul(q, k_t) * sm_scale
    for b in biases:
        a += b
    a = F.softmax(a, dim=-1)
    # Now, a is in shape [*, H, Dim_Q, Dim_Q], v is in shape [*, H, Dim_Q, C_hid]
    # Matmul operation results in [*, H, Dim_Q, C_hid]
    a_v = torch.matmul(a, v)
    # Transpose back to the input layout: [*, Dim_Q, H, C_hid]
    o = a_v.transpose(-2, -3)
    return o
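
# A minimal correctness-check sketch: the helper name, problem sizes, and tolerances below are
# choices made for illustration. It assumes the fused op applies the same 1/sqrt(C_hid) scaling
# as the reference and that loose fp16 tolerances are acceptable for the fused kernel.
def check_correctness(batch=1, n=16, seq=64, h=4, c=32):
    q = torch.randn(batch, n, seq, h, c, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, n, seq, h, c, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, n, seq, h, c, dtype=torch.float16, device="cuda")
    bias1 = torch.randn(batch, n, 1, 1, seq, dtype=torch.float16, device="cuda")
    bias2 = torch.randn(batch, 1, h, seq, seq, dtype=torch.float16, device="cuda")
    out = DS4Sci_EvoformerAttention(q, k, v, [bias1, bias2])
    ref = attention_reference(q, k, v, [bias1, bias2], 1 / (c**0.5))
    assert torch.allclose(out, ref, atol=2e-2, rtol=1e-2), "fused op deviates from reference"
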

# Problem sizes used throughout the benchmark.
dtype = torch.float16
N = 256
heads = 4
dim = 32
seq_len = 256


@contextlib.contextmanager
def cuda_timer(res_list):
    # Time the enclosed region with accelerator events and append the elapsed
    # time in milliseconds to res_list.
    start = get_accelerator().Event(enable_timing=True)
    end = get_accelerator().Event(enable_timing=True)
    start.record()
    yield
    end.record()
    get_accelerator().synchronize()
    res_list.append(start.elapsed_time(end))

def benchmark():
    ours_fw = []
    ours_bw = []
    baseline_fw = []
    baseline_bw = []
    for batch in range(1, 17):
        Q = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        K = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        V = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=False)
        bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True)
        # warm up
        DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
        with cuda_timer(ours_fw):
            out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
        d_out = torch.rand_like(out)
        with cuda_timer(ours_bw):
            out.backward(d_out)
        # warm up
        attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
        with cuda_timer(baseline_fw):
            ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
        with cuda_timer(baseline_bw):
            ref_out.backward(d_out)
    # Report the measured times (in milliseconds) per batch size.
    print("batch size\tours (FW)\tbaseline (FW)\tours (BW)\tbaseline (BW)")
    for i in range(len(ours_fw)):
        print(f"{i+1}\t{ours_fw[i]}\t{baseline_fw[i]}\t{ours_bw[i]}\t{baseline_bw[i]}")
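
# Rough throughput helper (a sketch): convert a measured forward time in milliseconds to TFLOP/s.
# The FLOP count covers only the two matmuls (Q @ K^T and A @ V), i.e. 4 * seq_len^2 * dim FLOPs
# per head and per element of the N dimension, and ignores the softmax and bias additions.
# Usage: approx_forward_tflops(batch, measured_forward_time_ms).
def approx_forward_tflops(batch, time_ms):
    flops = 4 * batch * N * heads * seq_len * seq_len * dim
    return flops / (time_ms * 1e-3) / 1e12
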

benchmark()