123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- // Copyright (c) Microsoft Corporation.
- // SPDX-License-Identifier: Apache-2.0
- // DeepSpeed Team
- /*
- DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
- https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
- */
// Block-sparse softmax forward, computed in place: each program normalizes one
// row of the sparse matrix stored in X and writes the result back into X.
//
// Compile-time macros (supplied by the host code, not defined here):
//   TYPE    element type of X, RPE, KP_M and ATTN_M
//   BLOCK   edge length of a square sparse block
//   TN      lanes per program; lane i covers column i % BLOCK of the
//           (i / BLOCK)-th non-zero block of this row. Lanes past the row's
//           last block are masked off, so TN presumably >= sizemax * BLOCK
//           (TODO confirm with the launcher).
//   APPLY_SCALE / APPLY_RPE / APPLY_KP_MASK / APPLY_ATTN_MASK
//           enable the corresponding term
//   KP_MASK_MUL / ATTN_MASK_MUL
//           treat that mask as multiplicative (0 -> -inf, anything else -> 0)
//           instead of additive
//
// Launch grid:
//   program_id(0) = pidhm: rxm = pidhm % BLOCK is the row inside the block,
//                   rbm = pidhm / BLOCK selects the LUT header entry.
//   program_id(1) = pidz: outer index applied through stride_zx, stride_zrpe
//                   and stride_zkpm (presumably the batch index; confirm).
//
// LUT layout (as read below):
//   LUT[2*rbm + 0] = size    number of non-zero blocks in this row
//   LUT[2*rbm + 1] = offset  start of this row's entries
//   LUT[offset + 4*k + 0..3] = blockid, columnid, rowid, headid
//
// Memory layout implied by the indexing, with
// row = rowid*BLOCK + rxm and col = columnid*BLOCK + rxn:
//   X      : each non-zero block is BLOCK*BLOCK contiguous, row-major elements
//            at X + pidz*stride_zx + blockid*BLOCK*BLOCK.
//   RPE    : RPE + pidz*stride_zrpe + headid*stride_hrpe + row*stride_srpe + col
//   KP_M   : KP_M + pidz*stride_zkpm + col (indexed by column only)
//   ATTN_M : ATTN_M + row*stride_zattnm + col. There is no pidz term, so one
//            mask is shared across program_id(1), and stride_zattnm acts as a
//            row stride despite its name.
//   num_blocks and sizemax are not read in this body.
//
// X is intentionally NOT marked __readonly: the result is stored back through it.
__global__ void softmax_fwd(TYPE *X __noalias __aligned(16),
                            float scale,
                            int *LUT __readonly __noalias __aligned(16),
                            TYPE *RPE __readonly __noalias __aligned(16),
                            TYPE *KP_M __readonly __noalias __aligned(16),
                            TYPE *ATTN_M __readonly __noalias __aligned(16),
                            int num_blocks,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zrpe __multipleof(BLOCK),
                            int stride_hrpe __multipleof(BLOCK),
                            int stride_srpe __multipleof(BLOCK),
                            int stride_zkpm __multipleof(BLOCK),
                            int stride_zattnm __multipleof(BLOCK)){
    int pidhm = get_program_id(0);
    int pidz = get_program_id(1);
    // index ranges: row inside the block, LUT header index, and per-lane
    // (column-in-block, block-in-row) coordinates
    int rxm = pidhm % BLOCK;
    int rbm = pidhm / BLOCK;
    int rxn[TN] = (0 ... TN) % BLOCK;
    int rbn[TN] = (0 ... TN) / BLOCK;
    // look-up table header for this row: {size, offset}
    int* header = LUT + rbm * 2;
    int size = *(header + 0);
    int offset = *(header + 1);
    // Lanes past the row's last non-zero block are inactive. Their LUT index is
    // clamped to the last valid entry so the LUT reads below stay in range.
    // NOTE(review): if size == 0 this clamps to entry -1; presumably every row
    // has at least one block (confirm).
    bool check[TN] = rbn < size;
    int rbmn[TN] = check ? rbn : size - 1;
    // per-lane block metadata (4 ints per LUT entry)
    long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
    long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
    long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
    long headid [TN] = *(LUT + offset + rbmn*4 + 3);
    // pointers to X
    TYPE* px[TN] = X + pidz * stride_zx
                     + blockid * BLOCK * BLOCK
                     + rxm * BLOCK
                     + rxn;
#ifdef APPLY_RPE
    // pointers to relative position embedding
    TYPE* prpe[TN] = RPE + pidz * stride_zrpe
                         + headid * stride_hrpe
                         + columnid * BLOCK
                         + rowid * BLOCK * stride_srpe
                         + rxm * stride_srpe
                         + rxn;
#endif
#ifdef APPLY_KP_MASK
    // pointers to key padding mask
    TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
                           + columnid * BLOCK
                           + rxn;
#endif
#ifdef APPLY_ATTN_MASK
    // pointers to attention mask
    TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
                               + rowid * BLOCK * stride_zattnm
                               + rxm * stride_zattnm
                               + rxn;
#endif
    // load input; inactive lanes get -INFINITY so they vanish from max/exp
    TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
    // load relative position embedding
    TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
    // load key-padding mask
    TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
    // load attention mask
    TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
    // compute softmax in float
#ifdef APPLY_RPE
    float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
    float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
    float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
    // multiplicative mask -> additive: 0 becomes -inf, anything else becomes 0
    Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
    // multiplicative mask -> additive: 0 becomes -inf, anything else becomes 0
    Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
    float Fx[TN] = x;
#ifdef APPLY_SCALE
    Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
    Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
    Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
    Fx = Fx + Fattn_m; // apply attention mask
#endif
    // subtract the row max before exp for numerical stability
    // NOTE(review): if every lane of the row is -INFINITY (fully masked),
    // Fxmax is -INFINITY, Fx - Fxmax is NaN, and the row is written as NaN.
    float Fxmax = Fx[max];
    float Fy[TN] = exp(Fx - Fxmax);
    float Fysum = (check ? Fy : 0)[+];
    // Normalize in float and round to TYPE once. Converting Fy and Fysum to
    // TYPE before dividing rounds twice. When TYPE is half, it also turns any
    // row sum above 65504 into inf, which would zero the whole row.
    TYPE y[TN] = Fy / Fysum;
    // write back in place, active lanes only
    *?(check)px = y;
}
|