Two recent papers introduce approximations to Transformer self-attention with runtime $\ll L^2$.

  1. Masked Language Modeling for Proteins via Linearly Scalable Long-Context Transformers introduces Fast attention via orthogonal random features (FAVOR).
  2. Linformer: Self-Attention with Linear Complexity introduces linear self-attention.

I’m going to summarize the main contribution of each paper, but you should definitely go read them for details such as the theorems and other theoretical insights. For consistency, I write everything using the notation from the FAVOR paper.

Transformer dot-product attention

Let $L$ be the size of an input sequence of tokens. Then transformer dot-product attention is a mapping which accepts matrices $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V} \in \mathbb{R}^{L×d}$ as input where $d$ is the hidden dimension. Matrices $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ are intermediate representations of the input and their rows can be interpreted as queries, keys and values of the continuous dictionary data structure respectively. Transformer dot-product attention is defined as

\[\operatorname{Att}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathbf{D}^{-1}\mathbf{AV}\]

where the attention matrix is $\mathbf{A} = \operatorname{exp}\left(\frac{\mathbf{QK}^T}{\sqrt{d}}\right)$ (with $\operatorname{exp}$ applied elementwise) and $\mathbf{D} = \operatorname{diag}(\mathbf{A}\mathbf{1}_L)$ is the normalizing factor. For details, see the original paper or this post from Harvard NLP.

While dot-product attention has proven very useful, its runtime and memory scale as $O(L^2d + Ld)$ and $O(L^2)$, respectively, because $\mathbf{A}\in \mathbb{R}^{L\times L}$ must be computed and stored explicitly. In practice, applications in natural language processing commonly limit the sequence length to 512 or 1024. This is also a significant constraint when modeling proteins. For example, Streptococcus pyogenes CRISPR-Cas9 is 1368 amino acids long.
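
To make the $L \times L$ bottleneck concrete, here’s a minimal NumPy sketch of unmasked dot-product attention as defined above (my own code, not from either paper); note that $\mathbf{A}$ is materialized explicitly.

```python
import numpy as np

def dot_product_attention(Q, K, V):
    """Exact (unmasked) dot-product attention: D^{-1} A V with A = exp(Q K^T / sqrt(d))."""
    L, d = Q.shape
    A = np.exp(Q @ K.T / np.sqrt(d))             # L x L attention matrix, stored explicitly
    D_inv = 1.0 / A.sum(axis=1, keepdims=True)   # row-wise normalizer D^{-1}
    return D_inv * (A @ V)                       # L x d output

rng = np.random.default_rng(0)
L, d = 512, 64
Q, K, V = rng.normal(size=(3, L, d))
out = dot_product_attention(Q, K, V)             # shape (512, 64); A alone is 512 x 512
```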

Fast attention via orthogonal random features (FAVOR)

The attention matrix $\mathbf{A}$ can be decomposed as \(\mathbf{A} = \mathbf{D_Q B D_K}\) with

\[\mathbf{D_T} = \operatorname{diag}\left[ \operatorname{exp}\left(\frac{\|\mathbf{T}_1\|_2^2}{2\sqrt{d}}\right)\ldots \operatorname{exp}\left(\frac{\|\mathbf{T}_L\|_2^2}{2\sqrt{d}}\right)\right] \forall \: \mathbf{T} \in \{\mathbf{Q}, \mathbf{K}\}\]

and

\[\mathbf{B} \in \mathbb{R}^{L \times L}, \quad B_{i, j} = \operatorname{exp}\left(-\frac{\|\mathbf{Q}_i - \mathbf{K}_j\|_2^2}{2\sqrt{d}}\right)\]

Naively, $\mathbf{D_T}$ requires $O(Ld)$ time to compute while $\mathbf{B}$ requires $O(L^2d)$, arriving at the overall time complexity of $O(L^2d + Ld)$ for dot-product attention.
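
As a sanity check on this decomposition (mine, not from the paper), the following NumPy snippet builds $\mathbf{A}$ both directly and as $\mathbf{D_Q B D_K}$ and confirms the two agree:

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 8, 4
Q, K = rng.normal(size=(2, L, d))

# Direct attention matrix: A = exp(Q K^T / sqrt(d))
A = np.exp(Q @ K.T / np.sqrt(d))

# Decomposition A = D_Q B D_K
D_Q = np.diag(np.exp(np.sum(Q**2, axis=1) / (2 * np.sqrt(d))))
D_K = np.diag(np.exp(np.sum(K**2, axis=1) / (2 * np.sqrt(d))))
sq_dists = np.sum((Q[:, None, :] - K[None, :, :])**2, axis=-1)  # ||Q_i - K_j||^2
B = np.exp(-sq_dists / (2 * np.sqrt(d)))

assert np.allclose(A, D_Q @ B @ D_K)
```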

FAVOR is a fast method for approximating $\mathbf{B}$. $\mathbf{B}$ is the Gaussian (squared-exponential) kernel matrix between the rows of $\mathbf{Q}$ and $\mathbf{K}$ with $\sigma = d ^ {\frac{1}{4}}$. Like many kernels used in machine learning, the Gaussian kernel has a fast random feature approximation. Given a random mapping $\phi: \mathbb{R}^d \to \mathbb{R}^M$ of the form

\[\phi(\mathbf{x}) = \sqrt{\frac{2}{M}}\operatorname{cos}(\mathbf{Wx} + \mathbf{b})^T\]

where $W_{i, j} \sim \mathcal{N}(0, \sigma^{-2})$ and $b_i \sim \operatorname{Unif}(0, 2\pi)$, the kernel is recovered in expectation:

\[K(\mathbf{x}, \mathbf{y}) = \mathbb{E}\left[\phi(\mathbf{x})^T\phi(\mathbf{y})\right]\]

As described in the paper, choosing orthogonal random features (the rows of $\mathbf{W}$) instead of sampling them independently decreases the variance of the approximation.
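
Here’s a small sketch (mine, using plain i.i.d. Gaussian features rather than the orthogonal construction from the paper) checking that this feature map approximates the Gaussian kernel with $\sigma = d^{1/4}$:

```python
import numpy as np

rng = np.random.default_rng(0)
d, M = 16, 10000
sigma = d ** 0.25                                # sigma^2 = sqrt(d)

W = rng.normal(scale=1.0 / sigma, size=(M, d))   # W_ij ~ N(0, sigma^-2)
b = rng.uniform(0.0, 2.0 * np.pi, size=M)

def phi(x):
    """Random feature map phi(x) = sqrt(2/M) cos(W x + b)."""
    return np.sqrt(2.0 / M) * np.cos(W @ x + b)

x, y = 0.5 * rng.normal(size=(2, d))
exact = np.exp(-np.sum((x - y)**2) / (2 * np.sqrt(d)))  # Gaussian kernel K(x, y)
approx = phi(x) @ phi(y)                                 # ~ E[phi(x)^T phi(y)]
print(exact, approx)                                     # the two values should be close
```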

Define randomly-featurized queries and keys as $\mathbf{\hat{Q}} = \sqrt{\frac{2}{M}}\operatorname{cos}(\mathbf{WQ}^T + \mathbf{b})^T$ and $\mathbf{\hat{K}} = \sqrt{\frac{2}{M}}\operatorname{cos}(\mathbf{WK}^T + \mathbf{b})^T$, so that $\mathbf{B} = \mathbb{E}\left[\mathbf{\hat{Q}}\mathbf{\hat{K}}^T\right]$. Combine these with the diagonal factors: $\mathbf{Q'} = \mathbf{D_Q}\mathbf{\hat{Q}}$ and $\mathbf{K'} = \mathbf{D_K}\mathbf{\hat{K}}$. Since $\mathbf{D_Q}$ and $\mathbf{D_K}$ are deterministic, \(\mathbf{A} = \mathbf{D_Q}\,\mathbb{E}\left[\mathbf{\hat{Q}}\mathbf{\hat{K}}^T\right]\mathbf{D_K} = \mathbb{E}\left[\mathbf{Q'}\mathbf{K'}^T\right]\).

Thus $\mathbf{\hat{A}} = \mathbf{Q'}\mathbf{K'}^T$ is an unbiased estimator of $\mathbf{A}$. However, we would still like to avoid computing and storing the full $L \times L$ attention matrix. We can do this when calculating approximate dot-product attention by being clever about how we group the matrix multiplications:

\[\operatorname{\hat{Att}}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathbf{\hat{D}}^{-1}\mathbf{\hat{A}V} = \mathbf{\hat{D}}^{-1}(\mathbf{Q'}(\mathbf{K'}^T\mathbf{V}))\]

with

\[\mathbf{\hat{D}} = \operatorname{diag}(\mathbf{Q'}(\mathbf{K'}^T\mathbf{1}_L))\]

This requires $O(LMd)$ time and $O(Md + Ld + ML)$ space.
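
Putting the pieces together, here’s a rough NumPy sketch of the whole procedure (again mine, with i.i.d. rather than orthogonal random features and none of the paper’s other refinements); the key point is that only $L \times M$ and $M \times d$ arrays are ever formed:

```python
import numpy as np

def favor_attention(Q, K, V, M=256, rng=None):
    """Approximate unmasked attention without materializing the L x L matrix."""
    rng = rng or np.random.default_rng(0)
    L, d = Q.shape
    sigma = d ** 0.25

    W = rng.normal(scale=1.0 / sigma, size=(M, d))   # W_ij ~ N(0, sigma^-2)
    b = rng.uniform(0.0, 2.0 * np.pi, size=M)

    # Random-featurized queries and keys: L x M
    Q_hat = np.sqrt(2.0 / M) * np.cos(Q @ W.T + b)
    K_hat = np.sqrt(2.0 / M) * np.cos(K @ W.T + b)

    # Fold in the diagonal renormalizers D_Q and D_K
    Q_prime = np.exp(np.sum(Q**2, axis=1, keepdims=True) / (2 * np.sqrt(d))) * Q_hat
    K_prime = np.exp(np.sum(K**2, axis=1, keepdims=True) / (2 * np.sqrt(d))) * K_hat

    # Group the multiplications so the L x L matrix never appears
    KV = K_prime.T @ V                        # M x d
    D_hat = Q_prime @ K_prime.sum(axis=0)     # length-L vector, = Q'(K'^T 1_L)
    return (Q_prime @ KV) / D_hat[:, None]    # L x d

rng = np.random.default_rng(0)
L, d = 1024, 64
Q, K, V = 0.1 * rng.normal(size=(3, L, d))
out = favor_attention(Q, K, V, M=256)         # shape (1024, 64)
```

The orthogonal construction from the paper would slot in where $\mathbf{W}$ is sampled.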

Linear self-attention

Instead of approximating the attention matrix $\mathbf{A}$, the second paper directly approximates the result of dot-product attention using linear attention. First, they show that the attention matrix is approximately low-rank, and then propose replacing the $L \times L$ attention matrix with an $L \times k$ approximation by using a learned weight matrix $\mathbf{E} \in \mathbb{R}^{k \times L}$ to project $\mathbf{K}$ into $\mathbb{R}^{k \times d}$:

\[\mathbf{\hat{A}} = \operatorname{exp}\left(\frac{\mathbf{Q}(\mathbf{EK})^T}{\sqrt{d}}\right)\]

Likewise, the values are also projected to $\mathbb{R}^{k \times d}$, and

\[\operatorname{\hat{Att}}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathbf{\hat{D}}^{-1}\mathbf{\hat{A}FV}\]

with $\mathbf{F} \in \mathbb{R}^{k \times L}$ and

\[\mathbf{\hat{D}} = \operatorname{diag}(\mathbf{\hat{A}}\mathbf{1}_k)\]

This requires $O(Lkd)$ time. In practice, the authors find that $k \geq 256$ seems to perform well.
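
Here’s a minimal NumPy sketch of this computation (mine; $\mathbf{E}$ and $\mathbf{F}$ are random stand-ins for the learned projections, and the paper additionally discusses parameter-sharing variants):

```python
import numpy as np

def linformer_attention(Q, K, V, E, F):
    """Linear self-attention: project K and V from L x d down to k x d."""
    L, d = Q.shape
    A_hat = np.exp(Q @ (E @ K).T / np.sqrt(d))   # L x k instead of L x L
    D_hat = A_hat.sum(axis=1, keepdims=True)     # normalizer diag(A_hat 1_k)
    return (A_hat @ (F @ V)) / D_hat             # L x d

rng = np.random.default_rng(0)
L, d, k = 1024, 64, 256
Q, K, V = rng.normal(size=(3, L, d))
E, F = rng.normal(scale=1.0 / np.sqrt(L), size=(2, k, L))  # stand-ins for learned weights
out = linformer_attention(Q, K, V, E, F)         # shape (1024, 64)
```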

Discussion

Some random thoughts on these papers.

  • While both of these approximations perform comparably to dot-product attention in the experiments presented in the papers, those experiments seem pretty perfunctory to me. I’d be more impressed by a demonstration that one of these approximations opens up new possibilities on a real problem. Concatenating random proteins together to length 8096 isn’t a real problem!
  • The FAVOR paper presents a general kernel framework for attention and investigates the performance of some simple kernels in its experiments. It’d be fun to try to design attention kernels for specific problem domains.
  • The FAVOR paper makes a big deal about having an unbiased estimator for $\mathbf{A}$, but I’m pretty sure its $\operatorname{\hat{Att}}$ is not an unbiased estimator of dot-product attention.
  • The linear projections in linear self-attention ($\mathbf{E}$ and $\mathbf{F}$) are dependent on the sequence length, which would make dealing with variable-length inputs messy.