# DS4Sci_EvoformerAttention_bench.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
This script benchmarks the performance of the DS4Sci_EvoformerAttention op
against a pure-PyTorch reference implementation.
To run the script,
1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git
2. Specify the CUTLASS_PATH environment variable. E.g. export CUTLASS_PATH=$(pwd)/cutlass
3. Run the script. E.g. python DS4Sci_EvoformerAttention_bench.py
"""
import contextlib
import torch
from typing import List
from torch.nn import functional as F
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator


def attention_reference(
        q_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        k_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        v_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        biases: List[torch.Tensor],
        sm_scale: float) -> torch.Tensor:
    # Original shape: [*, Dim_Q, H, C_hid] -> Transpose to: [*, H, Dim_Q, C_hid]
    q = q_input.transpose(-2, -3)
    k = k_input.transpose(-2, -3)
    v = v_input.transpose(-2, -3)
    # Now, q, k, v are in shape: [*, H, Dim_Q, C_hid]
    # Transpose k to shape [*, H, C_hid, Dim_Q]
    k_t = k.transpose(-1, -2)
    # Now, q and k_t are in shapes: [*, H, Dim_Q, C_hid] and [*, H, C_hid, Dim_Q] respectively
    # [*, H, Dim_Q, Dim_Q]
    a = torch.matmul(q, k_t) * sm_scale
    for b in biases:
        a += b
    a = F.softmax(a, dim=-1)
    # Now, a is in shape [*, H, Dim_Q, Dim_Q], v is in shape [*, H, Dim_Q, C_hid]
    # Matmul operation results in [*, H, Dim_Q, C_hid]
    a_v = torch.matmul(a, v)
    # [*, Dim_Q, H, C_hid]
    o = a_v.transpose(-2, -3)
    return o
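

# Added sketch (not part of the original script): a quick correctness check
# that compares the fused kernel against attention_reference on small random
# inputs. The shapes mirror the benchmark below; the function name and the
# fp16 tolerance are assumptions for illustration only, and nothing below
# calls this function.
def check_correctness():
    b, n, s, h, d = 1, 4, 64, 4, 32
    q = torch.randn(b, n, s, h, d, dtype=torch.float16, device="cuda")
    k = torch.randn(b, n, s, h, d, dtype=torch.float16, device="cuda")
    v = torch.randn(b, n, s, h, d, dtype=torch.float16, device="cuda")
    bias1 = torch.randn(b, n, 1, 1, s, dtype=torch.float16, device="cuda")
    bias2 = torch.randn(b, 1, h, s, s, dtype=torch.float16, device="cuda")
    out = DS4Sci_EvoformerAttention(q, k, v, [bias1, bias2])
    ref = attention_reference(q, k, v, [bias1, bias2], 1 / (d**0.5))
    # Loose tolerance to allow for fp16 accumulation differences (assumed value).
    assert torch.allclose(out, ref, atol=2e-2), "kernel and reference outputs diverge"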


dtype = torch.float16
batch = 1
N = 256
heads = 4
dim = 32
seq_len = 256


@contextlib.contextmanager
def cuda_timer(res_list):
    # Time the enclosed block with accelerator events and append the elapsed
    # milliseconds to res_list.
    start = get_accelerator().Event(enable_timing=True)
    end = get_accelerator().Event(enable_timing=True)
    start.record()
    yield
    end.record()
    get_accelerator().synchronize()
    res_list.append(start.elapsed_time(end))
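

# Added illustration (not in the original script): a minimal, self-contained
# example of how cuda_timer is used. The helper name _timer_example is an
# assumption; it is never called by the benchmark below.
def _timer_example():
    timings = []
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    with cuda_timer(timings):
        torch.matmul(a, b)
    # timings[-1] now holds the elapsed milliseconds of the matmul above.
    return timings[-1]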


def benchmark():
    ours_fw = []
    ours_bw = []
    baseline_fw = []
    baseline_bw = []
    for batch_size in range(1, 17):
        Q = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        K = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        V = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
        bias1 = torch.randn(batch_size, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=True)
        bias2 = torch.randn(batch_size, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True)
        # warm up
        DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
        with cuda_timer(ours_fw):
            out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
        d_out = torch.rand_like(out)
        with cuda_timer(ours_bw):
            out.backward(d_out)
        # warm up
        attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
        with cuda_timer(baseline_fw):
            ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
        with cuda_timer(baseline_bw):
            ref_out.backward(d_out)
    print("batch size\tours (FW)\tbaseline (FW)\tours (BW)\tbaseline (BW)")
    for i in range(len(ours_fw)):
        print(f"{i + 1}\t{ours_fw[i]}\t{baseline_fw[i]}\t{ours_bw[i]}\t{baseline_bw[i]}")


benchmark()