test_DS4Sci_EvoformerAttention.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import pytest
import torch
from torch.nn import functional as F

import deepspeed
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator
from unit.util import skip_on_arch

if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]:
    pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system", allow_module_level=True)


def attention_reference(
        q_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        k_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        v_input: torch.Tensor,  # [*, Dim_Q, H, C_hid]
        biases: List[torch.Tensor],
        sm_scale: float) -> torch.Tensor:
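    """Plain-PyTorch reference for attention with additive biases.

    Inputs use the [*, Dim_Q, H, C_hid] layout; heads are moved ahead of the
    sequence dimension for the matmuls and the result is transposed back.
    """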
    # Move heads ahead of the sequence dimension: [*, H, Dim_Q, C_hid]
    q = q_input.transpose(-2, -3)
    k = k_input.transpose(-2, -3)
    v = v_input.transpose(-2, -3)
    k_t = k.transpose(-1, -2)
    # Scaled attention logits plus broadcast additive biases
    a = torch.matmul(q, k_t) * sm_scale
    for b in biases:
        a += b
    a = F.softmax(a, dim=-1)
    a_v = torch.matmul(a, v)
    # Back to the input layout [*, Dim_Q, H, C_hid]
    o = a_v.transpose(-2, -3)
    return o


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)])
def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
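    """Compare the fused DS4Sci_EvoformerAttention kernel against the PyTorch
    reference: the forward output and the gradients w.r.t. Q, K, V and both
    biases must match within the tolerances asserted below."""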
    skip_on_arch(8 if dtype == torch.bfloat16 else 7)
    batch, n, seq_len, heads, dim = tensor_shape
    # Q, K, V: [batch, n, seq_len, heads, dim]
    Q = torch.randn(batch, n, seq_len, heads, dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    K = torch.randn(batch, n, seq_len, heads, dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    V = torch.randn(batch, n, seq_len, heads, dim,
                    dtype=dtype,
                    device=get_accelerator().device_name(),
                    requires_grad=True)
    # bias1 broadcasts over heads and query positions: [batch, n, 1, 1, seq_len]
    bias1 = torch.randn(batch, n, 1, 1, seq_len,
                        dtype=dtype,
                        device=get_accelerator().device_name(),
                        requires_grad=True)
    # bias2 is shared across the n dimension: [batch, 1, heads, seq_len, seq_len]
    bias2 = torch.randn(batch, 1, heads, seq_len, seq_len,
                        dtype=dtype,
                        device=get_accelerator().device_name(),
                        requires_grad=True)
    # Random upstream gradient shared by both backward passes
    dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name())

    # Reference implementation: forward, backward, then stash and reset the grads
    ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
    ref_out.backward(dummy_out)
    ref_dv, V.grad = V.grad.clone(), None
    ref_dk, K.grad = K.grad.clone(), None
    ref_dq, Q.grad = Q.grad.clone(), None
    ref_db1, bias1.grad = bias1.grad.clone(), None
    ref_db2, bias2.grad = bias2.grad.clone(), None

    # Fused kernel: forward, backward, then stash the grads
    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
    out.backward(dummy_out)
    dv, V.grad = V.grad.clone(), None
    dk, K.grad = K.grad.clone(), None
    dq, Q.grad = Q.grad.clone(), None
    db1, bias1.grad = bias1.grad.clone(), None
    db2, bias2.grad = bias2.grad.clone(), None

    assert torch.allclose(ref_out, out, atol=2e-2, rtol=0), f"\n{ref_out} \n {out}"
    assert torch.allclose(ref_dv, dv, atol=2e-2, rtol=0), f"\n{ref_dv} \n {dv}"
    assert torch.allclose(ref_dk, dk, atol=2e-2, rtol=0), f"\n{ref_dk} \n {dk}"
    assert torch.allclose(ref_dq, dq, atol=2e-2, rtol=0), f"\n{ref_dq} \n {dq}"
    assert torch.allclose(ref_db1, db1, atol=2e-2, rtol=1e-2), f"{ref_db1} \n {db1}"
    assert torch.allclose(ref_db2, db2, atol=2e-2, rtol=1e-2), f"{ref_db2} \n {db2}"