Flash Attention
#optimization #attention #transformers
An efficient, IO-aware attention algorithm
Standard Attention
We have the query, key, and value matrices $Q, K, V \in \mathbb{R}^{N \times d}$ stored in HBM. The standard implementation proceeds as follows (a sketch is given after this list):
- Load $Q$ and $K$ by blocks from HBM to the GPU's SRAM.
- Compute $S = QK^{\top}$ in the GPU's SRAM and then move it to HBM.
- Move $S$ from HBM, compute $P = \mathrm{softmax}(S)$, and then write $P$ to HBM.
- Load $P$ and $V$ by blocks from HBM, compute $O = PV$, and then move $O$ to HBM.
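Below is a minimal NumPy sketch of this pattern (the scaling by $\sqrt{d}$ and the toy inputs are assumptions for illustration, not part of the note above). The point is that the $N \times N$ intermediates $S$ and $P$ are fully materialized, which on a GPU means extra round trips to HBM.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention that materializes the full N x N score matrix.

    In the GPU version described above, the intermediates S and P are
    written to and re-read from HBM between the steps."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # scores S, shape (N, N)
    S_max = S.max(axis=-1, keepdims=True)    # row-wise max for stability
    P = np.exp(S - S_max)
    P = P / P.sum(axis=-1, keepdims=True)    # softmax rows P, shape (N, N)
    return P @ V                             # output O, shape (N, d)

# toy usage
rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = rng.standard_normal((3, N, d))
print(standard_attention(Q, K, V).shape)     # (8, 4)
```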
Flash Attention
Goal - Minimize the number of HBM accesses (reads/writes) by keeping intermediate results in on-chip SRAM.
To compute a numerically stable softmax:
- Find the max value $m = \max_i x_i$ of the input vector.
- Before exponentiating each term, subtract the max from it: $e^{x_i - m}$.
- Keep a running sum $\ell = \sum_j e^{x_j - m}$ for the denominator, so that $\mathrm{softmax}(x)_i = e^{x_i - m} / \ell$.
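A small sketch of these three steps (the overflow-prone example input is an assumption for illustration):

```python
import numpy as np

def safe_softmax(x):
    """Numerically stable softmax over a 1-D vector."""
    m = np.max(x)          # 1. find the max of the input vector
    e = np.exp(x - m)      # 2. subtract it before exponentiating, so exponents <= 0
    return e / np.sum(e)   # 3. the running sum of exponentials is the denominator

x = np.array([1000.0, 1001.0, 1002.0])   # naive exp(x) would overflow here
print(safe_softmax(x))                   # ~[0.090, 0.245, 0.665]
```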
We would like a way to compute the softmax online, while streaming over the input terms, instead of only after all of them have been computed. If we can do that, we can compute the softmax without moving the full score matrix between HBM and SRAM.
Let $x \in \mathbb{R}^{B}$ be one block of scores and define $m(x) = \max_i x_i$, $f(x) = [e^{x_1 - m(x)}, \dots, e^{x_B - m(x)}]$, and $\ell(x) = \sum_i f(x)_i$, so that $\mathrm{softmax}(x) = f(x)/\ell(x)$. For the concatenation of two blocks $x = [x^{(1)}, x^{(2)}]$ these statistics can be merged: $m(x) = \max\big(m(x^{(1)}), m(x^{(2)})\big)$ and $\ell(x) = e^{m(x^{(1)}) - m(x)}\,\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\,\ell(x^{(2)})$, which lets us update the softmax statistics block by block.
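A sketch of this merge rule as a streaming scan over blocks (the block size and toy input are assumptions for illustration; Flash Attention additionally folds the same rescaling into the output accumulator rather than doing a second pass over $x$):

```python
import numpy as np

def online_softmax(x, block_size=4):
    """One-pass (streaming) softmax statistics over blocks of x.

    Keeps a running max m and running denominator l; whenever a new block
    raises the max, the old sum is rescaled by exp(m_old - m_new)."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, np.max(block))
        l = l * np.exp(m - m_new) + np.sum(np.exp(block - m_new))
        m = m_new
    return np.exp(x - m) / l               # final normalization

x = 100.0 * np.random.default_rng(1).standard_normal(10)
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), reference)
```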