建议补全前置知识: GPU Arch:自顶向下分析
Transformer 作为GPT类模型的基础架构提供了强大的特征处理能力,但是处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有O(N^2)的时间和内存复杂度。 😓
这篇Flash Attention的工作深入硬件,新提出了一种具有IO感知的,快速的⚡️,节省内存的🧠,精确的🎯注意力算法。目前,Flash Attention已经集成至torch2.0,并且社区也提供了多种实现,接下来我们以Triton实现为例简单介绍一下这篇工作,
核心要点
-
⚡️为什么加快了计算?Fast
- 降低了耗时的HBM访问次数。采用Tiling技术分块从HBM加载数据到SRAM进行融合计算。
-
🧠为什么节省了内存?Memory-Efficient
- 不再对中间矩阵S,P进行存储。在反向的时候通过Recomputation重新计算来计算梯度。
-
🎯为什么是精准注意力?Exact Attention
- 算法流程只是分块计算,无近似操作。
提出问题
Transformer 结构已成为自然语言处理和图像分类等应用中最常用的架构。尽管 Transformer 在规模上不断增大和加深,但处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有二次方的时间和内存复杂度。这导致在处理长序列时速度变慢且内存需求巨大。因此,我们需要一些优化算法来提高注意力模块的计算速度和内存利用率。
解决方案
Bili视频演示: ⏱️78s看懂FlashAttention【有点意思·1】_哔哩哔哩_bilibili
ManimCode: github.com/cauyxy/bili…
Forward
Standard Attention Implementation
在注意力的一般实现中,对Q,K,V∈RN×dmathbf{Q}, mathbf{K}, mathbf{V} in mathbb{R}^{N times d}三个输入执行以下算法得到输出Omathbf{O},其中softmax行级别执行。
S=QK⊤∈RN×N,P=softmax(S)∈RN×N,O=PV∈RN×d,begin{equation} mathbf{S}=mathbf{Q K}^{top} in mathbb{R}^{N times N}, quad mathbf{P}=operatorname{softmax}(mathbf{S}) in mathbb{R}^{N times N}, quad mathbf{O}=mathbf{P} mathbf{V} in mathbb{R}^{N times d}, end{equation}
在这个算法中,S,Pmathbf{S}, mathbf{P} 矩阵都是很大,需要在HBM中实例化来进行存储,这样就会带来很多HBM的访问次数,最终体现到算法时间端到端较长的延迟。
FlashAttention Implementation(Tiling)
理论基础
在传统算法中,一种方式是将Mask和SoftMax部分融合,以减少访存次数。然而,FlashAttention则更加激进,它将从输入Qmathbf{Q}, Kmathbf{K}, Vmathbf{V}到输出Omathbf{O}的整个过程进行融合,以避免 S,Pmathbf{S}, mathbf{P} 矩阵的存储开销,实现端到端的延迟缩减。然而,由于输入的长度NN通常很长,无法完全将完整的 Q,K,V,Omathbf{Q}, mathbf{K}, mathbf{V},mathbf{O} 及中间计算结果存储在SRAM中。因此,需要依赖HBM进行访存操作,与原始计算延迟相比没有太大差异,甚至会变慢(没具体测)。
为了让计算过程的结果完全在SRAM中,摆脱对HBM的依赖,可以采用分片操作,每次进行部分计算,确保这些计算结果能在SRAM内进行交互,待得到对应的结果后再进行输出。
这个过程中,有一点需要注意的是,之前对于softmax的计算是以行为单位的,如下所示:
m(x):=maxixi,f(x):=[ex1−m(x)…exB−m(x)],ℓ(x):=∑if(x)i,softmax(x):=f(x)ℓ(x)begin{equation} m(x):=max _i x_i, quad f(x):=left[begin{array}{lll} e^{x_1-m(x)} & ldots & e^{x_B-m(x)} end{array}right], quad ell(x):=sum_i f(x)_i, quad operatorname{softmax}(x):=frac{f(x)}{ell(x)} end{equation}
当我们将输入进行分片后,无法对完整的行数据执行Softmax操作。这是因为Softmax函数在计算时需要考虑整个行的数据。然而,我们可以通过如下所示方法来获得与完整行Softmax相同的结果,而无需使用近似操作。
m(x)=m([x(1)x(2)])=max(m(x(1)),m(x(2))),f(x)=[em(x(1))−m(x)f(x(1))em(x(2))−m(x)f(x(2))],ℓ(x)=ℓ([x(1)x(2)])=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2)),softmax(x)=f(x)ℓ(x).begin{equation} begin{aligned} & m(x)=mleft(left[x^{(1)} x^{(2)}right]right)=max left(mleft(x^{(1)}right), mleft(x^{(2)}right)right), quad f(x)=left[begin{array}{ll} e^{mleft(x^{(1)}right)-m(x)} fleft(x^{(1)}right) & e^{mleft(x^{(2)}right)-m(x)} fleft(x^{(2)}right) end{array}right], \ & ell(x)=ellleft(left[x^{(1)} x^{(2)}right]right)=e^{mleft(x^{(1)}right)-m(x)} ellleft(x^{(1)}right)+e^{mleft(x^{(2)}right)-m(x)} ellleft(x^{(2)}right), quad operatorname{softmax}(x)=frac{f(x)}{ell(x)} . end{aligned} end{equation}
具体的分块softmax代码演示:github.com/cauyxy/bili…
代码实现
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
IO Complexity Analysis
Standard Attention
对于标准注意力实现,初期我们需要把输入Q,K,Vmathbf{Q}, mathbf{K}, mathbf{V}从HBM中读取,并计算完毕后把输出Omathbf{O}写入到HBM中。
第一步把Q,Kmathbf{Q}, mathbf{K}读取出来计算出S=QK⊤mathbf{S}=mathbf{Q K}^{top},然后把Smathbf{S}存回去,内存访问复杂度Θ(Nd+N2)Thetaleft(N d+N^2right)。
第二步把Smathbf{S}读取出来计算出P=softmax(Smathbf{P}=operatorname{softmax}(mathbf{S},然后把Pmathbf{P}存回去,内存访问复杂度Θ(N2)Thetaleft(N^2right)。
第三步把V,Pmathbf{V}, mathbf{P}读取出来计算出O=PVmathbf{O}=mathbf{P} mathbf{V},然后计算出结果Omathbf{O},内存访问复杂度Θ(Nd+N2)Thetaleft(N d+N^2right)。
综上所述,整体的内存访问复杂度为Θ(Nd+N2)Thetaleft(N d+N^2right)。
FlashAttention
对于FlashAttention,我们设置一个分块大小BcB_c来把K,Vmathbf{K}, mathbf{V}分成TcT_c块,对于Q,Omathbf{Q}, mathbf{O}的每一块都要把K,Vmathbf{K}, mathbf{V }部分的全部元素Load一遍,这样则有FlashAttention的内存访问复杂度为Θ(Nd+NdTc)Thetaleft(N d+N d T_cright)=Θ(NdTc)Thetaleft(N d T_cright).
在这里,我们需要两个分块大小,Q,Omathbf{Q}, mathbf{O}的分块大小BrB_r,K,Vmathbf{K}, mathbf{V}的分块大小BcB_c,我们设定SRAM的大小为MM,为了能把分块后的K,V ∈RBc×dmathbf{K}, mathbf{V} in mathbb{R}^{B_c times d}放进SRAM,那么则有一下限制:
Bcd=O(M)⇔Bc=O(Md)begin{equation} B_c d=O(M) Leftrightarrow B_c=Oleft(frac{M}{d}right) end{equation}
相应的,Q,O ∈RBr×dmathbf{Q}, mathbf{O} in mathbb{R}^{B_r times d}有如下限制:
Brd=O(M)⇔Br=O(Md)begin{equation} B_r d=O(M) Leftrightarrow B_r=Oleft(frac{M}{d}right) end{equation}
最终,还有一个中间态S=QK⊤ ∈RBr×Bcmathbf{S}=mathbf{Q K}^{top} in mathbb{R}^{B_r times B_c}需要存储,则有如下限制:
BrBc=O(M)begin{equation} B_r B_c=O(M) end{equation}
综上,限制如下
Bc=Θ(Md),Br=Θ(min(Md,MBc))=Θ(min(Md,d))begin{equation} B_c=Thetaleft(frac{M}{d}right), quad B_r=Thetaleft(min left(frac{M}{d}, frac{M}{B_c}right)right)=Thetaleft(min left(frac{M}{d}, dright)right) end{equation}
进而推出
Tc=NBc=Θ(NdM)begin{equation} T_c=frac{N}{B_c}=Thetaleft(frac{N d}{M}right) end{equation}
那么在M=Θ(NdM = Theta (Nd的前提下,则有FlashAttention的HBM内存访问复杂度为:
Θ(NdTc)=Θ(N2d2M) = Θ(Nd)begin{equation} Thetaleft(N d T_cright)=Thetaleft(frac{N^2 d^2}{M}right) = Thetaleft({N d}right) end{equation}
在语言建模中,通常有d ⋘ Nd lll N,则有Θstand(Nd+N2) > Θflash(Nd)Theta_{stand} left(N d+N^2right) > Theta_{flash} left(N dright)。这样,在前向的过程中,我们采用分块计算的方式,避免了S,Pmathbf{S}, mathbf{P}矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗。
Backward
理论基础
在上面前向的时候我们为了减少HBM访存次数,降低内存消耗量,我们并没有对S,Pmathbf{S}, mathbf{P}矩阵进行存储,而这个在反向传播计算梯度的时候确实需要的一个信息。之前有通过Gradient checkpointing的方式来实现梯度实现在前向的时候更加节省内存。
我们这里则采用重新计算的方式来计算对应的梯度。在上面前向计算的时候我们不会存储S,Pmathbf{S}, mathbf{P}矩阵,但是我们会存储对应的指数项之和LL来进行梯度的计算。
我们在反向的过程中最重要的事情就是就是Loss函数ϕphi对Q,K,V,Omathbf{Q}, mathbf{K}, mathbf{V}, mathbf{O}对应的梯度。
Omathbf{O}对应的梯度最好计算dO=∂ϕ∂Omathbf{dO} = frac{partial phi}{partial mathbf{O}},其中Omathbf{O}是现成的。
Vmathbf{V}对应的梯度也很好计算,由于O = PVmathbf{O} = mathbf{P}mathbf{V} ,根据链式求导法则和矩阵求导法则则有dV=PTdOmathbf{d V}=mathbf{P}^T mathbf{d} mathbf{O},更详细如下所示:
dvj=∑iPijdoi=∑ieqiTkjLidoibegin{equation} d v_j=sum_i P_{i j} d o_i=sum_i frac{e^{q_i^T k_j}}{L_i} d o_i end{equation}
Q, Kmathbf{Q}, mathbf{K}对应的梯度算起来就比较复杂一点。这两个经过的计算逻辑步骤更多,我们可以一步一步的来进行计算。我们可以先计算dP, dSmathbf{dP}, mathbf{dS}。由于O = PVmathbf{O} = mathbf{P}mathbf{V} ,则有dPmathbf{dP}如下表示
dPij=doiTvjbegin{equation} d P_{i j}=d o_i^T v_j end{equation}
Fact: y=softmax(x)y=operatorname{softmax}left(xright)的雅各比矩阵为diag(y)−yyToperatorname{diag}(y)-y y^T,具体推导见 Derivative of the Softmax Function and the Categorical Cross-Entropy Loss
由于Pi:=softmax(Si:)P_{i:}=operatorname{softmax}left(S_{i:}right),根据上述定理则有:
dSi:=(diag(Pi:)−Pi:Pi:T)dPi:=Pi:∘dPi:−(Pi:TdPi:)Pi:begin{equation} d S_{i:}=left(operatorname{diag}left(P_{i:}right)-P_{i:} P_{i:}^Tright) d P_{i:}=P_{i:} circ d P_{i:}-left(P_{i:}^T d P_{i:}right) P_{i:} end{equation}
接下来我们定义如下表示:
Di=Pi:TdPi:=∑eqiκjLidoiTvj=doiT∑eqiκjLivj=doiToibegin{equation} D_i=P_{i:}^T d P_{i:}=sum frac{e^{q_i kappa_j}}{L_i} d o_i^T v_j=d o_i^T sum frac{e^{q_i kappa_j}}{L_i} v_j=d o_i^T o_i end{equation}
根据上述定义简化12式则有如下表示:
dSi:=Pi:∘dPi:−DiPi:begin{equation} d S_{i:}=P_{i:} circ d P_{i:}-D_i P_{i:} end{equation}
相应的dSmathbf{dS}可表示为如下形式:
dSij=PijdPij−DiPij=Pij(dPij−Di)begin{equation} d S_{i j}=P_{i j} d P_{i j}-D_i P_{i j}=P_{i j}left(d P_{i j}-D_iright) end{equation}
又因为Sij=qiTkjS_{i j}=q_i^T k_j,结合上述推导利用链式求导法则Q, Kmathbf{Q}, mathbf{K}对应的梯度有如下表示:
dqi=∑jdSijkj=∑jPij(dPij−Di)kj=∑jeqiTkjLi(doiTvj−Di)kjbegin{equation} d q_i=sum_j d S_{i j} k_j=sum_j P_{i j}left(d P_{i j}-D_iright) k_j=sum_j frac{e^{q_i^T k_j}}{L_i}left(d o_i^T v_j-D_iright) k_j end{equation}
dkj=∑idSijqi=∑iPij(dPij−Di)qi=∑ieqiTkjLi(doiTvj−Di)qibegin{equation} d k_j=sum_i d S_{i j} q_i=sum_i P_{i j}left(d P_{i j}-D_iright) q_i=sum_i frac{e^{q_i^T k_j}}{L_i}left(d o_i^T v_j-D_iright) q_i end{equation}
至此,我们得到了一个完整的包含前向和反向的,降低了HBM访问次数的,新的Attention算子。
代码实现
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
Block-Sparse
相比于上面的全量计算,块稀疏的FlashAttention需要额外提供一个Mask矩阵M~∈{0,1}N×Ntilde{mathbf{M}} in{0,1}^{N times N}用于将一些元素置零来保证块稀疏加速计算。本文对于块稀疏的一个计算只是一个简单的尝试,没有进行太深入的探索,所以这里我们先一笔带过,后面我们可以讲一篇对FlashAttention进行块稀疏优化的工作SCFA.
S=QK⊤∈RN×N,P=softmax(S⊙1M~)∈RN×N,O=PV∈RN×dbegin{equation} mathbf{S}=mathbf{Q} mathbf{K}^{top} in mathbb{R}^{N times N}, quad mathbf{P}=operatorname{softmax}left(mathbf{S} odot mathbb{1}_{tilde{mathbf{M}}}right) in mathbb{R}^{N times N}, quad mathbf{O}=mathbf{P V} in mathbb{R}^{N times d} end{equation}
实验验证
通过实验验证发现,FlashAttention在速度和内存占用方面都表现出明显的优势,并取得了良好的效果。目前,FlashAttention已经经过广泛验证, torch2.0中已提供flashattention的实现。正如标题《Fast and Memory-Efficient Exact Attention with IO-Awareness》所示,FlashAttention的优点在于充分考虑了在计算任务中IO的重要性,并通过分块计算的方式开发了一种快速、节省显存、精确无近似的注意力实现方法。这使得我们更便于训练具有更长上下文的Transformer模型,并且为后续注意力算法的优化提供了一个基准。