Part of Advances in Neural Information Processing Systems 36 (NeurIPS 2023) Main Conference Track
Josh Alman, Zhao Song
In modern machine learning, inner product attention computation is a fundamental task for training large language models such as Transformer, GPT-1, BERT, GPT-2, GPT-3, and ChatGPT. Formally, in this problem, one is given as input three matrices $Q, K, V \in [-B, B]^{n \times d}$, and the goal is to construct the matrix $\mathrm{Att}(Q,K,V) := \mathrm{diag}(A \mathbf{1}_n)^{-1} A V \in \mathbb{R}^{n \times d}$, where $A = \exp(QK^\top / d)$ is the `attention matrix', and $\exp$ is applied entry-wise. Straightforward methods for this problem explicitly compute the $n \times n$ attention matrix $A$, and hence require time $\Omega(n^2)$ even when $d = n^{o(1)}$ is small. In this paper, we investigate whether faster algorithms are possible by \emph{implicitly} making use of the matrix $A$. We present two results, showing that there is a sharp transition at $B = \Theta(\sqrt{\log n})$.

∙ If $d = O(\log n)$ and $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$-time algorithm to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error.

∙ If $d = O(\log n)$ and $B = \Theta(\sqrt{\log n})$, then assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory, it is impossible to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error in truly subquadratic time $n^{2-\Omega(1)}$.

This gives a theoretical explanation for the phenomenon observed in practice that attention computation is much more efficient when the input matrices have smaller entries.
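For concreteness, the straightforward $\Omega(n^2)$ baseline described above can be written in a few lines. The following is a minimal NumPy sketch (not from the paper; the function and variable names are illustrative) that explicitly forms the $n \times n$ matrix $A$ and then row-normalizes it, which is exactly the quadratic-time computation the paper's two results are compared against.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Straightforward O(n^2 d) computation of Att(Q, K, V) = diag(A 1_n)^{-1} A V,
    where the n x n matrix A = exp(Q K^T / d) is formed explicitly."""
    n, d = Q.shape
    A = np.exp(Q @ K.T / d)              # n x n attention matrix, exp applied entry-wise
    row_sums = A @ np.ones(n)            # the vector A 1_n
    return (A / row_sums[:, None]) @ V   # diag(A 1_n)^{-1} A V, an n x d matrix

if __name__ == "__main__":
    # Illustrative usage: entries drawn from [-B, B] with small B.
    rng = np.random.default_rng(0)
    n, d, B = 8, 4, 1.0
    Q, K, V = (rng.uniform(-B, B, (n, d)) for _ in range(3))
    print(naive_attention(Q, K, V).shape)  # (8, 4)
```

Because this sketch materializes all $n^2$ entries of $A$, its running time is quadratic in $n$ regardless of how small $d$ is; the paper's algorithmic result avoids this by using $A$ only implicitly when $B = o(\sqrt{\log n})$.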