Flash Attention

#optimization #attention #transformers
An Efficient Attention Process

Standard Attention

We have the matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM.

  1. Load $Q$, $K$ by blocks from HBM to the GPU's SRAM.
  2. Compute $S = QK^\top$ in the GPU's SRAM and then write it to HBM.
  3. Read $S$ from HBM, compute $P = \mathrm{softmax}(S)$, and write $P$ to HBM.
  4. Load $P$ and $V$ by blocks from HBM, compute $O = PV$, and then write $O$ to HBM.
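
The four steps above can be sketched in plain Python to make the two $N \times N$ intermediates explicit. This is only an illustration of the arithmetic on lists of rows, not a GPU kernel: the helper names `matmul`, `softmax_rows`, and `standard_attention` are made up for this note, and the sketch follows the steps as written, so it omits the $1/\sqrt{d}$ scaling that attention implementations usually apply.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(s):
    """Row-wise softmax; each row's max is subtracted before exponentiating."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out

def standard_attention(q, k, v):
    """q, k, v are N x d matrices (lists of rows); returns O, also N x d."""
    k_t = [list(col) for col in zip(*k)]  # K^T, shape d x N
    s = matmul(q, k_t)                    # step 2: S = Q K^T, shape N x N
    p = softmax_rows(s)                   # step 3: P = softmax(S), shape N x N
    return matmul(p, v)                   # step 4: O = P V, shape N x d
```

On a GPU, `s` and `p` are the matrices that get written to and read back from HBM, which is the traffic Flash Attention tries to avoid.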

Flash Attention

Goal - Minimize the number of HBM accesses

To compute a numerically stable softmax of a vector $x \in \mathbb{R}^B$ ->

  1. $m(x) = \max_i x_i$ -> Find the max value in the input vector.
  2. $f(x) = [e^{x_1 - m(x)}, \dots, e^{x_B - m(x)}]$ -> Subtract the max from each term before exponentiating it.
  3. $l(x) = \sum_i f(x)_i$ -> The sum that forms the denominator, so that $\mathrm{softmax}(x) = f(x) / l(x)$.
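
As a minimal sketch of these three quantities, assuming the input is a plain Python list of floats (the function name `stable_softmax` is chosen for this note):

```python
import math

def stable_softmax(x):
    m = max(x)                        # m(x): max of the input vector
    f = [math.exp(v - m) for v in x]  # f(x): exponentials shifted by the max
    l = sum(f)                        # l(x): the softmax denominator
    return [v / l for v in f]
```

Because every exponent is at most zero after the shift, inputs such as `[1000.0, 1001.0]` do not overflow, whereas calling `math.exp(1000.0)` directly raises `OverflowError`.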

We want to compute the softmax online, block by block, rather than only after all the input terms have been calculated. If we can do that, we can compute the softmax without moving $S$ back and forth between SRAM and HBM.

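Let $x = [x^{(1)}, x^{(2)}]$ be the concatenation of two blocks; then we can write $f(x)$ and $l(x)$ as follows:

$f(x) = \left[ e^{m(x^{(1)}) - m(x)} f(x^{(1)}), \; e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \right]$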
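
The decomposition of $f(x)$ can be checked numerically. The sketch below computes $m$ and $f$ for each block separately, then rescales each block's $f$ by $e^{m(\text{block}) - m(x)}$, using the fact that $m(x) = \max(m(x^{(1)}), m(x^{(2)}))$. The names `block_stats` and `merge_f` are hypothetical helpers for this note, and the blocks are plain Python lists.

```python
import math

def block_stats(x):
    """Return m(x) and f(x) for a single block."""
    m = max(x)
    f = [math.exp(v - m) for v in x]
    return m, f

def merge_f(x1, x2):
    """Build f of the concatenated vector from per-block results only."""
    m1, f1 = block_stats(x1)
    m2, f2 = block_stats(x2)
    m = max(m1, m2)  # m(x) is the larger of the two block maxima
    left = [math.exp(m1 - m) * v for v in f1]   # e^{m(x1) - m(x)} f(x1)
    right = [math.exp(m2 - m) * v for v in f2]  # e^{m(x2) - m(x)} f(x2)
    return left + right
```

For example, `merge_f([1.0, 3.0], [2.0, 5.0])` matches the second element returned by `block_stats([1.0, 3.0, 2.0, 5.0])` up to floating-point rounding, so neither block ever needs to see the other's raw entries.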