// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
 https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/

/*
 * softmax_fwd: forward softmax over one row of a block-sparse matrix, written
 * back in place through the pointers derived from X.
 *
 * Program ids:
 *   get_program_id(0) -> pidhm; pidhm % BLOCK is the row inside a block (rxm),
 *                        pidhm / BLOCK selects the look-up-table header (rbm).
 *   get_program_id(1) -> pidz; multiplied by the stride_z* arguments.
 *
 * Look-up table layout, as read by this kernel:
 *   LUT[rbm*2 + 0]        = size   (number of blocks in this row)
 *   LUT[rbm*2 + 1]        = offset (start of this row's entries inside LUT)
 *   LUT[offset + 4*i + k] = blockid (k=0), columnid (k=1), rowid (k=2),
 *                           headid (k=3) for the i-th block of the row.
 *
 * Optional stages, selected by preprocessor symbols:
 *   APPLY_SCALE      multiply the input by `scale`
 *   APPLY_RPE        add the value loaded from RPE
 *   APPLY_KP_MASK    add the value loaded from KP_M
 *   APPLY_ATTN_MASK  add the value loaded from ATTN_M
 *   KP_MASK_MUL /    first remap the corresponding mask: entries equal to 0
 *   ATTN_MASK_MUL    become -INFINITY, every other entry becomes 0
 *
 * TYPE, BLOCK and TN are not defined in this file; they are assumed to be
 * supplied by whatever compiles the kernel -- TODO confirm against the launcher.
 *
 * NOTE(review): num_blocks and sizemax are not referenced in the body.
 * NOTE(review): X carries __readonly but the result is stored through px, which
 * points into X -- confirm this attribute is intended.
 * NOTE(review): if every active lane of a row ends up at -INFINITY, Fxmax is
 * -INFINITY and Fx - Fxmax evaluates to NaN; that case is not handled here.
 */
__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
                            float scale,
                            int *LUT __readonly __noalias __aligned(16),
                            TYPE *RPE __readonly __noalias __aligned(16),
                            TYPE *KP_M __readonly __noalias __aligned(16),
                            TYPE *ATTN_M __readonly __noalias __aligned(16),
                            int num_blocks,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zrpe __multipleof(BLOCK),
                            int stride_hrpe __multipleof(BLOCK),
                            int stride_srpe __multipleof(BLOCK),
                            int stride_zkpm __multipleof(BLOCK),
                            int stride_zattnm __multipleof(BLOCK)){
  int pidhm = get_program_id(0);
  int pidz = get_program_id(1);

  // create index ranges
  int rxm     = pidhm % BLOCK;
  int rbm     = pidhm / BLOCK;
  int rxn[TN] = (0 ... TN) % BLOCK;
  int rbn[TN] = (0 ... TN) / BLOCK;

  // extract information from look-up table
  int* header = LUT + rbm * 2;
  int size    = *(header + 0);
  int offset  = *(header + 1);

  // lanes whose block index is past the end of the row are inactive; they are
  // clamped to the last valid entry so the LUT reads below stay inside the row
  bool check[TN] = rbn < size;
  int   rbmn[TN] = check ? rbn : size - 1;

  // block id and column id
  long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
  long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
  long rowid   [TN] = *(LUT + offset + rbmn*4 + 2);
  long headid  [TN] = *(LUT + offset + rbmn*4 + 3);

  // pointers to X
  TYPE* px[TN] = X + pidz * stride_zx
                   + blockid * BLOCK * BLOCK
                   + rxm * BLOCK
                   + rxn;
#ifdef APPLY_RPE
  // pointers to relative position embedding
  TYPE* prpe[TN] = RPE + pidz * stride_zrpe
                       + headid * stride_hrpe
                       + columnid * BLOCK
                       + rowid * BLOCK * stride_srpe
                       + rxm * stride_srpe
                       + rxn;
#endif
#ifdef APPLY_KP_MASK
  // pointers to key padding mask
  TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
                         + columnid * BLOCK
                         + rxn;
#endif
#ifdef APPLY_ATTN_MASK
  // pointers to attention mask
  TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
                             + rowid * BLOCK * stride_zattnm
                             + rxm * stride_zattnm
                             + rxn;
#endif

  // load input; inactive lanes get -INFINITY so they contribute exp(...) = 0
  TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
  // load relative position embedding
  TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
  // load key-padding mask
  TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
  // load attention mask
  TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif

  // compute softmax in float
#ifdef APPLY_RPE
  float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
  float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
  float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
  // multiplicative mask -> additive mask: 0 becomes -INFINITY, anything else 0
  Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
  // multiplicative mask -> additive mask: 0 becomes -INFINITY, anything else 0
  Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
  float Fx[TN] = x;
#ifdef APPLY_SCALE
  Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
  Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
  Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
  Fx = Fx + Fattn_m; // apply attention mask
#endif

  // subtract the row maximum before exponentiating
  float Fxmax  = Fx[max];
  float Fy[TN] = exp(Fx - Fxmax);
  float Fysum  = (check ? Fy : 0)[+];

  // Normalize while still in float and narrow to TYPE only once, at the store.
  // Dividing after the narrowing rounds numerator and denominator separately,
  // and the denominator (which can be as large as TN) may exceed the finite
  // range of a 16-bit TYPE, which would zero the whole row.
  float Fout[TN] = Fy / Fysum;

  // write-back in half/float
  TYPE y[TN] = Fout;
  *?(check)px = y;
}