Update to the comments in reference_hstu_attention_bwd.hpp

This commit is contained in:
Qianfeng Zhang
2026-06-09 08:14:07 +00:00
parent 08873f0d50
commit 62627db768

View File

@@ -28,8 +28,11 @@ namespace ck_tile {
// LSE[sq] = log(sum_sk exp(S[sq,sk])) (kUseSoftmax=true, saved during fwd)
//
// Backward derivation:
// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k] -- GEMM: P^T @ dO
// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k] -- GEMM: dO @ V^T
// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k]
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
//
// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k]
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
//
// kUseSoftmax=false (SiLU path):
// dsilu(x) = sigmoid(x) * (1 + x*(1 - sigmoid(x)))
@@ -50,8 +53,11 @@ namespace ck_tile {
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
// (masked-out positions have P=0, so they contribute 0 naturally)
//
// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k] -- GEMM: dS @ K
// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k] -- GEMM: dS^T @ Q
// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
//
// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k]
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// clang-format on
template <typename InOutDataType,
@@ -235,7 +241,8 @@ struct reference_no_group_hstu_attention_bwd
{
// ------------------------------------------------------------------
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
// S[sq,sk] = alpha * Q[sq] . K[sk]^T (masked-in)
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
// ------------------------------------------------------------------
std::vector<CompDataType> locals_S(seqlen_kv);
@@ -311,8 +318,8 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k]
// dV = P^T @ dO (accumulates over sq)
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -333,7 +340,7 @@ struct reference_no_group_hstu_attention_bwd
// ------------------------------------------------------------------
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
// dP = dO @ V^T
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dP(seqlen_kv);
for(int sk = 0; sk < seqlen_kv; sk++)
@@ -414,8 +421,9 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 5: Compute dQ[sq,k] = alpha * dS[sq,:] . K
// dQ = alpha * dS @ K (computed fresh for each sq, no accumulation)
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// (computed fresh per sq row, no accumulation needed)
// ------------------------------------------------------------------
for(int k = 0; k < hdim_qk; k++)
{
@@ -443,8 +451,8 @@ struct reference_no_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k]
// dK = alpha * dS^T @ Q (accumulates over sq)
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -660,7 +668,8 @@ struct reference_group_hstu_attention_bwd
{
// ------------------------------------------------------------------
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
// S[sq,sk] = alpha * Q[sq] . K[sk]^T (masked-in)
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
// ------------------------------------------------------------------
std::vector<CompDataType> locals_S(seqlen_kv);
@@ -721,8 +730,8 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k]
// dV = P^T @ dO (accumulates over sq)
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{
@@ -738,7 +747,7 @@ struct reference_group_hstu_attention_bwd
// ------------------------------------------------------------------
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
// dP = dO @ V^T
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
// ------------------------------------------------------------------
std::vector<CompDataType> locals_dP(seqlen_kv);
for(int sk = 0; sk < seqlen_kv; sk++)
@@ -799,8 +808,9 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 5: Compute dQ[sq,k] = alpha * dS[sq,:] . K
// dQ = alpha * dS @ K (computed fresh for each sq, no accumulation)
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
// (computed fresh per sq row, no accumulation needed)
// ------------------------------------------------------------------
for(int k = 0; k < hdim_qk; k++)
{
@@ -819,8 +829,8 @@ struct reference_group_hstu_attention_bwd
}
// ------------------------------------------------------------------
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k]
// dK = alpha * dS^T @ Q (accumulates over sq)
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
// ------------------------------------------------------------------
for(int sk = 0; sk < seqlen_kv; sk++)
{