title: "DeepSpeed Sparse Attention" excerpt: "" date: 2020-09-09 01:00:00
Attention-based deep learning models, such as transformers, are highly effective at capturing relationships between tokens in an input sequence, even across long distances. As a result, they are used with text, image, and sound-based inputs, where the sequence length can be in the thousands of tokens. However, despite the effectiveness of attention modules at capturing long-term dependencies, in practice their application to long-sequence input is limited by the compute and memory requirements of the attention computation, which grow quadratically, O(n^2), with the sequence length n.
To address this limitation, DeepSpeed offers a suite of sparse attention kernels, an instrumental technology that can reduce the compute and memory requirements of attention computation by orders of magnitude via block-sparse computation. The suite not only alleviates the memory bottleneck of attention calculation, but also performs the sparse computation efficiently. Its APIs allow convenient integration with any transformer-based model. Along with providing a wide spectrum of sparsity structures, it has the flexibility of handling any user-defined block-sparse structure. More specifically, sparse attention (SA) can be designed to compute local attention between nearby tokens, or global attention via summary tokens computed with local attention. Moreover, SA can also allow random attention, or any combination of local, global, and random attention, as shown in the following figure with blue, orange, and green blocks, respectively. As a result, SA decreases the memory footprint to O(wn), in which 1 < w < n is a parameter whose value depends on the attention structure.
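As an illustration of how such a mixed local/global/random pattern yields O(wn) memory, the pattern can be described at the block level by a boolean layout: only blocks marked `True` are ever materialized. The sketch below is a standalone NumPy illustration, not the library's API; the function and parameter names (`make_layout`, `local_blocks`, etc.) are hypothetical.

```python
import numpy as np

def make_layout(seq_len, block=16, local_blocks=2, global_blocks=1,
                random_blocks=1, seed=0):
    """Build a block-level attention layout: layout[i, j] == True means the
    (i, j) block of the attention matrix is computed; all others are skipped."""
    rng = np.random.default_rng(seed)
    n = seq_len // block                      # number of blocks per dimension
    layout = np.zeros((n, n), dtype=bool)
    for i in range(n):
        # local attention: a window of nearby blocks around the diagonal
        lo = max(0, i - local_blocks)
        hi = min(n, i + local_blocks + 1)
        layout[i, lo:hi] = True
        # random attention: a few arbitrary blocks per row
        layout[i, rng.choice(n, size=random_blocks, replace=False)] = True
    # global attention: the first `global_blocks` block rows/columns see everything
    layout[:global_blocks, :] = True
    layout[:, :global_blocks] = True
    return layout

layout = make_layout(seq_len=256, block=16)
# each block row holds a bounded number of True blocks, so memory is O(w * n)
print(layout.shape, layout.sum(axis=1).max())
```

Because every block row contains only a bounded number of active blocks (local window plus a handful of global and random blocks), the total number of computed blocks grows linearly with sequence length rather than quadratically.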
This library is PyTorch-based and develops its required kernels through the Triton platform; the kernels are not written in CUDA, which leaves the door open for CPU/OpenCL/Vulkan support in the future. The library is an extension to DeepSpeed and can be used through DeepSpeed or as a standalone package.
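When used through DeepSpeed, sparse attention is enabled from the DeepSpeed configuration file. A minimal fragment using the `fixed` sparsity mode might look like the following; the field values here are illustrative, and the full schema is described in the DeepSpeed documentation:

```json
{
  "sparse_attention": {
    "mode": "fixed",
    "block": 16,
    "different_layout_per_head": true,
    "num_local_blocks": 4,
    "num_global_blocks": 1,
    "attention": "bidirectional",
    "horizontal_global_attention": false,
    "num_different_global_patterns": 4
  }
}
```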
Block-sparse computations handled by DeepSpeed Sparse Attention kernels are illustrated in the following figures for the forward and backward passes, respectively. In the figures, S stands for a block-sparse matrix and D for a dense matrix.
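These mixed dense/sparse products correspond to the block-sparse matmul modes in Triton-style kernels (often referred to as SDD, DSD, and DDS, depending on which operands are sparse). As a plain-NumPy sketch, not the library's kernel, the SDD case multiplies two dense matrices but computes only the output blocks present in the layout; `sdd_matmul` is a hypothetical name for illustration:

```python
import numpy as np

def sdd_matmul(A, B, layout, block):
    """Dense x Dense -> block-sparse output: compute only the output blocks
    whose position is True in `layout`; all other blocks stay zero."""
    n = layout.shape[0]
    out = np.zeros((A.shape[0], B.shape[1]))
    for i in range(n):
        for j in range(n):
            if layout[i, j]:
                rows = slice(i * block, (i + 1) * block)
                cols = slice(j * block, (j + 1) * block)
                out[rows, cols] = A[rows, :] @ B[:, cols]
    return out

block = 4
layout = np.eye(8, dtype=bool)            # diagonal-only layout for the demo
rng = np.random.default_rng(0)
A = rng.standard_normal((32, 32))
B = rng.standard_normal((32, 32))
S = sdd_matmul(A, B, layout, block)
# computed blocks agree with the dense product; skipped blocks are zero
assert np.allclose(S[:4, :4], (A @ B)[:4, :4]) and np.allclose(S[:4, 4:], 0)
```

In attention, this is exactly the QK^T product: with a block-sparse layout, the kernel touches only the score blocks that will survive masking, which is where the compute and memory savings come from.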
To learn more about the sparsity config, and also how to use this library, please check our tutorial, which provides detailed information about it.
We also compared SA with Longformer, a state-of-the-art sparse implementation. The experiment uses the `Fixed` sparsity pattern, and the two implementations have comparable accuracy. On system performance, SA outperforms Longformer both in training and inference:
| Model | Local Window Size | BPC | Train Step | Time Per Iteration | Time Improvement | Accuracy improvement |
|---|---|---|---|---|---|---|
| RoBERTa Checkpoint | | 2.5326 | | | | |
| Longformer | 512 | 2.6535 | 0 | | 1.47 | 1.01 |
| Sparse Attention | | 2.6321 | | | | |
| Longformer | | 1.6708 | 3k | 1.6280 | | 1.01 |
| Sparse Attention | | 1.6613 | | 1.1059 | | |
| Longformer | 64 | 5.7840 | 0 | | 1.31 | 1.46 |
| Sparse Attention | | 3.9737 | | | | |
| Longformer | | 2.0466 | 3k | 1.4855 | | 1.09 |
| Sparse Attention | | 1.8693 | | 1.1372 | | |
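The improvement factors in the table appear to be simple ratios of the Longformer and Sparse Attention figures (lower BPC and lower time per iteration are better), which can be verified directly:

```python
# accuracy improvement = Longformer BPC / Sparse Attention BPC
print(round(5.7840 / 3.9737, 2))   # 1.46 (window size 64, step 0)
print(round(2.0466 / 1.8693, 2))   # 1.09 (window size 64, 3k steps)

# time improvement = Longformer time/iter / Sparse Attention time/iter
print(round(1.6280 / 1.1059, 2))   # 1.47 (window size 512)
print(round(1.4855 / 1.1372, 2))   # 1.31 (window size 64)
```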
Inference was tested with a 2,048 sequence length and batch size 1. In this experiment, we noticed up to 3.13x speedup replacing BERT attention with DeepSpeed Sparse Attention instead of Longformer attention. The following table shows the complete result.

| Local Window Size | Time Improvement |
|---|---|
| 512 | 3.13 |
| 256 | 2.29 |
| 128 | 2.16 |
| 64 | 1.5 |
| 32 | 1.24 |
| 16 | 1.23 |
We also define a template to have a `variable` structure (top figure), which can be used to easily customize any block-sparse random/local/global attention pattern. In addition to this list, users can add any other sparsity structure as described in the tutorial section.
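As an illustration of what a user-defined pattern amounts to (the library's actual extension point is described in the tutorial; the sketch below is standalone NumPy and `strided_layout` is a hypothetical name), a custom sparsity structure is just a rule mapping block indices to a boolean layout:

```python
import numpy as np

def strided_layout(num_blocks, stride=4):
    """Hypothetical custom pattern: each query block attends to its own block,
    the previous block, and every `stride`-th block (strided "summary" blocks)."""
    layout = np.zeros((num_blocks, num_blocks), dtype=bool)
    for i in range(num_blocks):
        layout[i, i] = True               # own block
        if i > 0:
            layout[i, i - 1] = True       # previous block (local context)
        layout[i, ::stride] = True        # strided global-style blocks
    return layout

print(strided_layout(8, stride=4).astype(int))
```

Any such layout, however it is generated, can drive the same block-sparse kernels, which is what gives the library its flexibility for user-defined structures.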