Update to the comments in reference_hstu_attention_bwd.hpp

This commit is contained in:
Qianfeng Zhang
2026-06-12 15:31:16 +00:00
parent 7d2e575fed
commit 7a67ae4dd3

View File

@@ -21,18 +21,16 @@ namespace ck_tile {
// Given dO, Q, K, V, LSE (and the same mask parameters as the forward), compute dQ, dK, dV.
//
// Forward recap (see reference_hstu_attention_fwd.hpp):
// S[sq,sk] = alpha * Q[sq] . K[sk]^T (masked-in pairs), else 0 or -inf
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk]), masked-in pairs, else 0 or -inf
// P[sq,sk] = silu(S[sq,sk]) * scale_p (kUseSoftmax=false)
// = softmax_row(S)[sq,sk] (kUseSoftmax=true)
// O[sq,k] = sum_sk P[sq,sk] * V[sk,k]
// O = P @ V^T (A=P[sq,sk], B=V^T[hdim_v,sk])
// LSE[sq] = log(sum_sk exp(S[sq,sk])) (kUseSoftmax=true, saved during fwd)
//
// Backward derivation:
// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k]
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
//
// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k]
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
//
// kUseSoftmax=false (SiLU path):
// dsilu(x) = sigmoid(x) * (1 + x*(1 - sigmoid(x)))
@@ -53,11 +51,9 @@ namespace ck_tile {
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
// (masked-out positions have P=0, so they contribute 0 naturally)
//
// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
//
// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k]
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// clang-format on
template <typename InOutDataType,
@@ -241,7 +237,6 @@ struct reference_no_group_hstu_attention_bwd
{
// ------------------------------------------------------------------
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
// ------------------------------------------------------------------
@@ -318,8 +313,7 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// Step 2: dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -339,8 +333,7 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// Step 3: dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dP(seqlen_kv);
for(int sk = 0; sk < seqlen_kv; sk++)
@@ -376,7 +369,7 @@ struct reference_no_group_hstu_attention_bwd
// = 0 (masked-out)
//
// kUseSoftmax=true (Softmax):
// D[sq] = dO[sq] . O[sq] (uses forward output O directly)
// D[sq] = dO[sq] row(.) O[sq] (uses forward output O directly)
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dS(seqlen_kv);
@@ -394,7 +387,7 @@ struct reference_no_group_hstu_attention_bwd
}
else
{
// D[sq] = dO[sq] . O[sq]
// D[sq] = dO[sq] row(.) O[sq]
GemmAccDataType D_acc = 0.f;
for(int k = 0; k < hdim_v; k++)
{
@@ -421,8 +414,7 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// Step 5: dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// (computed fresh per sq row, no accumulation needed)
// ------------------------------------------------------------------
for(int k = 0; k < hdim_qk; k++)
@@ -451,8 +443,7 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// Step 6: dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -668,7 +659,6 @@ struct reference_group_hstu_attention_bwd
{
// ------------------------------------------------------------------
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
// ------------------------------------------------------------------
@@ -730,8 +720,7 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// Step 2: dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -746,8 +735,7 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// Step 3: dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dP(seqlen_kv);
for(int sk = 0; sk < seqlen_kv; sk++)
@@ -773,7 +761,7 @@ struct reference_group_hstu_attention_bwd
// = 0 (masked-out)
//
// kUseSoftmax=true (Softmax):
// D[sq] = dO[sq] . O[sq] (uses forward output O directly)
// D[sq] = dO[sq] row(.) O[sq] (uses forward output O directly)
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dS(seqlen_kv);
@@ -791,7 +779,7 @@ struct reference_group_hstu_attention_bwd
}
else
{
// D[sq] = dO[sq] . O[sq]
// D[sq] = dO[sq] row(.) O[sq]
GemmAccDataType D_acc = 0.f;
for(int k = 0; k < hdim_v; k++)
{
@@ -808,8 +796,7 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// Step 5: dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// (computed fresh per sq row, no accumulation needed)
// ------------------------------------------------------------------
for(int k = 0; k < hdim_qk; k++)
@@ -829,8 +816,7 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// Step 6: dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{