# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import numpy as np

from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator

kernel_ = None


def _attention(Q, K, V, bias1, bias2):
    assert Q.shape[-3] > 16, "seq_len must be greater than 16"
    O = torch.empty_like(Q, dtype=Q.dtype)
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda"
    assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda"
    # Build and load the CUDA kernel lazily on first use.
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    nheads = Q.shape[-2]
    # Query length rounded up to a multiple of 32 for the lse buffer.
    nq = (Q.shape[-3] + 31) // 32 * 32
    nb = np.prod(Q.shape[:-3])  # flattened batch size (all leading dims)
    # Per-(batch, head, query) softmax log-sum-exp statistics, saved for the backward pass.
    lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device)
    kernel_.attention(Q, K, V, bias1, bias2, O, lse)
    return O, lse
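

# Illustrative sketch (not part of the original module): shows the tensor shapes
# _attention expects. The sizes below are arbitrary example values, the helper name
# _example_attention_shapes is hypothetical, and a CUDA-capable accelerator on which
# the Evoformer attention op builds and loads is assumed.
def _example_attention_shapes():
    batch, n_seq, seq_len, heads, head_dim = 1, 4, 64, 8, 32
    device = get_accelerator().device_name()
    dtype = torch.float16
    Q = torch.randn(batch, n_seq, seq_len, heads, head_dim, dtype=dtype, device=device)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)
    # Judging by the shapes: bias1 is broadcast over heads and query positions,
    # bias2 over the second batch (n_seq) dimension.
    bias1 = torch.randn(batch, n_seq, 1, 1, seq_len, dtype=dtype, device=device)
    bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device=device)
    O, lse = _attention(Q, K, V, bias1, bias2)
    # O matches Q's shape; lse is fp32 of shape (batch * n_seq, heads, seq_len rounded up to 32).
    return O.shape, lse.shape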


def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad):
    assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
    dQ = torch.empty_like(Q, dtype=Q.dtype)
    dK = torch.empty_like(K, dtype=K.dtype)
    dV = torch.empty_like(V, dtype=V.dtype)
    assert get_accelerator().on_accelerator(dO), "dO must be on cuda"
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(O), "O must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    delta = torch.empty_like(lse)
    # Bias gradients are accumulated in fp32 when requested; otherwise an empty
    # placeholder tensor is passed to the kernel.
    if bias1_grad:
        dB1 = torch.zeros_like(bias1, dtype=torch.float32)
    else:
        dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device)
    if bias2_grad:
        dB2 = torch.zeros_like(bias2, dtype=torch.float32)
    else:
        dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device)
    kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
    # Cast the fp32 bias gradients back to the activation dtype.
    return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)


class EvoformerFusedAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, bias1=None, bias2=None):
        """
        q, k, v are of shape [*, L, H, D], where L is the sequence length,
        H the number of heads, and D the per-head hidden dimension.
        """
        bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        o, lse = _attention(q, k, v, bias1_, bias2_)
        ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_)
        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
        # Compute bias gradients only if the bias was supplied and autograd requests them.
        is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3]
        is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4]
        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad)
        if not is_b1_grad:
            dB1 = None
        if not is_b2_grad:
            dB2 = None
        return dQ, dK, dV, dB1, dB2


def DS4Sci_EvoformerAttention(Q, K, V, biases):
    assert len(biases) <= 2

    # Pad the bias list to exactly two entries; missing biases are treated as None.
    if len(biases) == 0:
        biases.append(None)

    if len(biases) == 1:
        biases.append(None)

    # Expected bias shapes for Q of shape [B, N, L, H, D]:
    #   bias1: [B, N, 1, 1, L]  (broadcast over heads and query positions)
    #   bias2: [B, 1, H, L, L]  (broadcast over the second batch dimension N)
    bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2])
    bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])

    if biases[0] is not None:
        assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect"

    if biases[1] is not None:
        assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect"

    return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
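

# Usage sketch (illustrative, not part of the original module): calls the public
# DS4Sci_EvoformerAttention entry point and backpropagates through it. The shapes
# and sizes below are example assumptions taken from the asserts above; a CUDA
# device on which the Evoformer attention kernel builds and loads is required.
if __name__ == "__main__":
    device = get_accelerator().device_name()
    dtype = torch.float16
    batch, n_seq, seq_len, heads, head_dim = 1, 4, 64, 8, 32
    Q = torch.randn(batch, n_seq, seq_len, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    K = torch.randn_like(Q, requires_grad=True)
    V = torch.randn_like(Q, requires_grad=True)
    # bias1: [B, N, 1, 1, L]; bias2: [B, 1, H, L, L] -- the shapes asserted above.
    bias1 = torch.randn(batch, n_seq, 1, 1, seq_len, dtype=dtype, device=device, requires_grad=True)
    bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device=device, requires_grad=True)
    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
    out.sum().backward()  # fills Q.grad, K.grad, V.grad, bias1.grad, bias2.grad
    print(out.shape, Q.grad.shape, bias2.grad.shape)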