// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_bwd(TYPE* X __readonly __noalias __aligned(16),
                            float scale,
                            TYPE* DX __readonly __noalias __aligned(16),
                            int* LUT,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zdx __multipleof(BLOCK)) {
    int pidhm = get_program_id(0);
    int pidz = get_program_id(1);

    // create index ranges: rxm/rbm locate the row within / across blocks,
    // rxn/rbn do the same along the columns ((0 ... TN) is Triton-C range syntax)
    int rxm = pidhm % BLOCK;
    int rbm = pidhm / BLOCK;
    int rxn[TN] = (0 ... TN) % BLOCK;
    int rbn[TN] = (0 ... TN) / BLOCK;

    // extract information from look-up table: each block-row has a 2-int
    // header holding its block count and the offset of its first block record
    int* header = LUT + rbm * 2;
    int size = *(header + 0);
    int offset = *(header + 1);

    // bounds checking on lut: clamp out-of-range lanes to the last valid block
    bool check[TN] = rbn < size;
    int rbmn[TN] = check ? rbn : size - 1;

    // initialize pointers to block-sparse input
    // (the first field of each 4-int LUT record is the block id)
    long blockid[TN] = *(LUT + offset + rbmn * 4);
    TYPE* px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
    TYPE* pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;

    // compute fused softmax backward in fp32
    // (x holds the softmax output written in place by the forward kernel)
    TYPE x[TN] = check ? *px : 0;
    TYPE dx[TN] = check ? *pdx : 0;
    float Fdx[TN] = dx;
    float Fx[TN] = x;
    float Fxdx[TN] = Fdx * Fx;
    float Fxdxsum = Fxdx[+]; // [+] is Triton-C's sum-reduction over the TN axis
    float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
    TYPE y[TN] = Fy;

    // write-back: *? (pred) is a predicated (masked) store
    *? (check)pdx = y;
}
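
// Dense-math equivalent of the fused backward above: with y = softmax(x)
// (already stored in X by the in-place forward pass) and dy the incoming
// gradient in DX, each row computes
//     dx_i = scale * y_i * (dy_i - sum_j y_j * dy_j)
// i.e. the standard softmax Jacobian-vector product. The kernel is launched
// on a 2D grid: axis 0 (pidhm) enumerates rows of non-empty block-rows,
// axis 1 (pidz) the batch/head dimension. A minimal sketch of an equivalent
// dense check, assuming PyTorch tensors y and dy of shape (..., N)
// (names hypothetical, for validation only):
//     s  = (y * dy).sum(-1, keepdim=True)
//     dx = scale * y * (dy - s)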