123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- // softmax_bwd: fused block-sparse softmax backward for one row per program instance; overwrites DX in place with X * (DX - sum(X * DX)) * scale.
- // DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a of https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
- __global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16), // X: base pointer of the block-sparse input; only read here (presumably the forward softmax output -- confirm against caller)
- float scale, // scale: scalar multiplied into every result element
- TYPE* DX __readonly __noalias __aligned(16), // DX: read as the incoming gradient, then overwritten with the result. NOTE(review): declared __readonly but stored to at the end of the kernel -- confirm the qualifier is intended
- int* LUT, // LUT: int look-up table; a (size, offset) pair per block-row, plus 4-int entries whose first int is a block id
- int sizemax, // sizemax: not referenced anywhere in this kernel body
- long stride_zx __multipleof(BLOCK), // stride_zx: element offset between consecutive pidz slices of X
- long stride_zdx __multipleof(BLOCK)) { // stride_zdx: element offset between consecutive pidz slices of DX
- int pidhm = get_program_id(0); // program id on axis 0: encodes block-row * BLOCK + row-within-block
- int pidz = get_program_id(1); // program id on axis 1: selects the slice addressed through stride_zx / stride_zdx
- // create index ranges: split the row id and the TN column lanes into (block index, index within block)
- int rxm = pidhm % BLOCK; // row inside the block
- int rbm = pidhm / BLOCK; // block-row, used to index the LUT header
- int rxn[TN] = (0 ... TN) % BLOCK; // column inside the block, one value per lane
- int rbn[TN] = (0 ... TN) / BLOCK; // position of the lane's block along this row, one value per lane
- // extract information from look-up table: each block-row owns a 2-int header at LUT + rbm * 2
- int* header = LUT + rbm * 2;
- int size = *(header + 0); // lanes with rbn >= size are treated as out of range below
- int offset = *(header + 1); // int offset from LUT to this block-row's 4-int entries
- // bounds checking on lut: out-of-range lanes are masked off and re-pointed at entry size - 1 (NOTE(review): that index is -1 if size is 0 -- presumably never the case, confirm)
- bool check[TN] = rbn < size;
- int rbmn[TN] = check ? rbn : size - 1;
- // initialize pointers to block-sparse input: slice offset + block offset (BLOCK * BLOCK elements per block) + row offset + column
- long blockid[TN] = *(LUT + offset + rbmn*4); // first int of each 4-int LUT entry
- TYPE* px[TN] = X + pidz * stride_zx
- + blockid * BLOCK * BLOCK
- + rxm * BLOCK
- + rxn;
- TYPE* pdx[TN] = DX + pidz * stride_zdx
- + blockid * BLOCK * BLOCK
- + rxm * BLOCK
- + rxn;
- // compute fused softmax backward in float: y = x * (dx - sum(x * dx)) * scale; masked lanes are given 0 instead of loaded values
- TYPE x[TN] = check ? *px : 0;
- TYPE dx[TN] = check ? *pdx : 0;
- float Fdx[TN] = dx; // widen to float before the arithmetic
- float Fx[TN] = x;
- float Fxdx[TN] = Fdx*Fx;
- float Fxdxsum = Fxdx[+]; // sum of x * dx over all TN lanes (masked lanes add 0)
- float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
- TYPE y[TN] = Fy; // convert the result back to TYPE
- // write-back: the result replaces the gradient values in DX
- *? (check)pdx = y; // store guarded by check (presumably Triton-C predicated-store syntax, so masked lanes are not written -- confirm)
- }
|