# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import numpy as np
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator

# Lazily loaded kernel module; built/loaded on first use.
kernel_ = None


def _attention(Q, K, V, bias1, bias2):
    assert Q.shape[-3] > 16, "seq_len must be greater than 16"
    O = torch.empty_like(Q, dtype=Q.dtype)
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda"
    assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    nheads = Q.shape[-2]
    # Pad the query length up to a multiple of 32 for the log-sum-exp buffer.
    nq = (Q.shape[-3] + 31) // 32 * 32
    nb = np.prod(Q.shape[:-3])
    lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device)
    kernel_.attention(Q, K, V, bias1, bias2, O, lse)
    return O, lse


def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad):
    assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
    dQ = torch.empty_like(Q, dtype=Q.dtype)
    dK = torch.empty_like(K, dtype=K.dtype)
    dV = torch.empty_like(V, dtype=V.dtype)
    assert get_accelerator().on_accelerator(dO), "dO must be on cuda"
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(O), "O must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    delta = torch.empty_like(lse)
    # An empty tensor is passed when the corresponding bias gradient is not needed.
    if bias1_grad:
        dB1 = torch.zeros_like(bias1, dtype=torch.float32)
    else:
        dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device)
    if bias2_grad:
        dB2 = torch.zeros_like(bias2, dtype=torch.float32)
    else:
        dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device)
    kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
    return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)


class EvoformerFusedAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, bias1=None, bias2=None):
        """
        q, k, v: tensors of shape [*, L, H, D]
        """
        bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        o, lse = _attention(q, k, v, bias1_, bias2_)
        ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_)
        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
        is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3]
        is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4]
        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad)
        if not is_b1_grad:
            dB1 = None
        if not is_b2_grad:
            dB2 = None
        return dQ, dK, dV, dB1, dB2


def DS4Sci_EvoformerAttention(Q, K, V, biases):
    """
    Fused Evoformer attention with up to two additive biases.

    Q, K, V: tensors of shape [B, N, L, H, D] on the accelerator device.
    biases: list of at most two bias tensors; biases[0] must have shape
            (B, N, 1, 1, L) and biases[1] shape (B, 1, H, L, L). A missing
            bias is treated as None.
    """
    assert len(biases) <= 2

    if (len(biases) == 0):
        biases.append(None)

    if (len(biases) == 1):
        biases.append(None)

    bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2])
    bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])

    if biases[0] is not None:
        assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect"

    if biases[1] is not None:
        assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect"

    return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
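

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a CUDA-capable accelerator on which the EvoformerAttn op can be
# built and loaded. The batch/sequence/head sizes below are arbitrary
# examples; the only hard constraints from the code above are seq_len > 16
# and head_dim <= 64. Shapes follow the [B, N, L, H, D] convention used by
# DS4Sci_EvoformerAttention and its bias-shape checks.
if __name__ == "__main__":
    device = get_accelerator().device_name()  # typically "cuda"
    dtype = torch.float16

    batch, n_seq, seq_len, heads, head_dim = 1, 4, 64, 8, 32

    # Q, K, V in shape [B, N, L, H, D]
    Q = torch.randn(batch, n_seq, seq_len, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    K = torch.randn_like(Q, requires_grad=True)
    V = torch.randn_like(Q, requires_grad=True)

    # bias1 broadcasts over heads and query positions: (B, N, 1, 1, L)
    bias1 = torch.randn(batch, n_seq, 1, 1, seq_len, dtype=dtype, device=device, requires_grad=True)
    # bias2 is a per-head pair bias: (B, 1, H, L, L)
    bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device=device, requires_grad=True)

    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
    out.sum().backward()
    print(out.shape, Q.grad.shape, bias1.grad.shape, bias2.grad.shape)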