diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention_bwd.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention_bwd.hpp index bdabba0f6b..d70a530a16 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention_bwd.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention_bwd.hpp @@ -28,8 +28,11 @@ namespace ck_tile { // LSE[sq] = log(sum_sk exp(S[sq,sk])) (kUseSoftmax=true, saved during fwd) // // Backward derivation: -// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k] -- GEMM: P^T @ dO -// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k] -- GEMM: dO @ V^T +// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k] +// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq]) +// +// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k] +// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v]) // // kUseSoftmax=false (SiLU path): // dsilu(x) = sigmoid(x) * (1 + x*(1 - sigmoid(x))) @@ -50,8 +53,11 @@ namespace ck_tile { // dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq]) // (masked-out positions have P=0, so they contribute 0 naturally) // -// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k] -- GEMM: dS @ K -// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k] -- GEMM: dS^T @ Q +// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k] +// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk]) +// +// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k] +// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq]) // clang-format on template locals_S(seqlen_kv); @@ -311,8 +318,8 @@ struct reference_no_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] - // dV = P^T @ dO (accumulates over sq) + // Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq) + // dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq]) // ------------------------------------------------------------------ for(int sk = 0; sk < seqlen_kv; sk++) { @@ -333,7 +340,7 @@ struct reference_no_group_hstu_attention_bwd // ------------------------------------------------------------------ // Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:] - // dP = dO @ V^T + // dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v]) // ------------------------------------------------------------------ std::vector locals_dP(seqlen_kv); for(int sk = 0; sk < seqlen_kv; sk++) @@ -414,8 +421,9 @@ struct reference_no_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 5: Compute dQ[sq,k] = alpha * dS[sq,:] . K - // dQ = alpha * dS @ K (computed fresh for each sq, no accumulation) + // Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k] + // dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk]) + // (computed fresh per sq row, no accumulation needed) // ------------------------------------------------------------------ for(int k = 0; k < hdim_qk; k++) { @@ -443,8 +451,8 @@ struct reference_no_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] - // dK = alpha * dS^T @ Q (accumulates over sq) + // Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq) + // dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq]) // ------------------------------------------------------------------ for(int sk = 0; sk < seqlen_kv; sk++) { @@ -660,7 +668,8 @@ struct reference_group_hstu_attention_bwd { // ------------------------------------------------------------------ // Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation) - // S[sq,sk] = alpha * Q[sq] . K[sk]^T (masked-in) + // S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in) + // S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk]) // P[sq,sk] = silu(S)*scale_p or softmax_row(S) // ------------------------------------------------------------------ std::vector locals_S(seqlen_kv); @@ -721,8 +730,8 @@ struct reference_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] - // dV = P^T @ dO (accumulates over sq) + // Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq) + // dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq]) // ------------------------------------------------------------------ for(int sk = 0; sk < seqlen_kv; sk++) { @@ -738,7 +747,7 @@ struct reference_group_hstu_attention_bwd // ------------------------------------------------------------------ // Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:] - // dP = dO @ V^T + // dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v]) // ------------------------------------------------------------------ std::vector locals_dP(seqlen_kv); for(int sk = 0; sk < seqlen_kv; sk++) @@ -799,8 +808,9 @@ struct reference_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 5: Compute dQ[sq,k] = alpha * dS[sq,:] . K - // dQ = alpha * dS @ K (computed fresh for each sq, no accumulation) + // Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k] + // dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk]) + // (computed fresh per sq row, no accumulation needed) // ------------------------------------------------------------------ for(int k = 0; k < hdim_qk; k++) { @@ -819,8 +829,8 @@ struct reference_group_hstu_attention_bwd } // ------------------------------------------------------------------ - // Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] - // dK = alpha * dS^T @ Q (accumulates over sq) + // Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq) + // dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq]) // ------------------------------------------------------------------ for(int sk = 0; sk < seqlen_kv; sk++) {