mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Update to the comments in reference_hstu_attention_bwd.hpp
This commit is contained in:
@@ -21,18 +21,16 @@ namespace ck_tile {
|
||||
// Given dO, Q, K, V, LSE (and the same mask parameters as the forward), compute dQ, dK, dV.
|
||||
//
|
||||
// Forward recap (see reference_hstu_attention_fwd.hpp):
|
||||
// S[sq,sk] = alpha * Q[sq] . K[sk]^T (masked-in pairs), else 0 or -inf
|
||||
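// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk]), masked-in pairs, else 0 or -inf
// (notation: in every "A @ B" form here, A is given as [M,K] and B as [N,K], so the product contracts over the trailing dimension of both operands, i.e. sum_K A[m,K] * B[n,K])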
|
||||
// P[sq,sk] = silu(S[sq,sk]) * scale_p (kUseSoftmax=false)
|
||||
// = softmax_row(S)[sq,sk] (kUseSoftmax=true)
|
||||
// O[sq,k] = sum_sk P[sq,sk] * V[sk,k]
|
||||
// O = P @ V^T (A=P[sq,sk], B=V^T[hdim_v,sk])
|
||||
// LSE[sq] = log(sum_sk exp(S[sq,sk])) (kUseSoftmax=true, saved during fwd)
|
||||
//
|
||||
// Backward derivation:
|
||||
// dV[sk,k] = sum_sq P[sq,sk] * dO[sq,k]
|
||||
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
//
|
||||
// dP[sq,sk] = sum_k dO[sq,k] * V[sk,k]
|
||||
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
//
|
||||
// kUseSoftmax=false (SiLU path):
|
||||
// dsilu(x) = sigmoid(x) * (1 + x*(1 - sigmoid(x)))
|
||||
@@ -53,11 +51,9 @@ namespace ck_tile {
|
||||
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
|
||||
// (masked-out positions have P=0, so they contribute 0 naturally)
|
||||
//
|
||||
// dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
|
||||
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
//
|
||||
// dK[sk,k] = alpha * sum_sq dS[sq,sk] * Q[sq,k]
|
||||
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// clang-format on
|
||||
|
||||
template <typename InOutDataType,
|
||||
@@ -241,7 +237,6 @@ struct reference_no_group_hstu_attention_bwd
|
||||
{
|
||||
// ------------------------------------------------------------------
|
||||
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
|
||||
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
|
||||
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
|
||||
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
|
||||
// ------------------------------------------------------------------
|
||||
@@ -318,8 +313,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
|
||||
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
// Step 2: dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
// ------------------------------------------------------------------
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
@@ -339,8 +333,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
|
||||
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
// Step 3: dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
// ------------------------------------------------------------------
|
||||
std::vector<CompDataType> locals_dP(seqlen_kv);
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
@@ -376,7 +369,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
// = 0 (masked-out)
|
||||
//
|
||||
// kUseSoftmax=true (Softmax):
|
||||
// D[sq] = dO[sq] . O[sq] (uses forward output O directly)
|
||||
// D[sq] = dO[sq] row(.) O[sq] (row(.) = per-row dot product over hdim_v; uses forward output O directly)
|
||||
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
|
||||
// ------------------------------------------------------------------
|
||||
std::vector<CompDataType> locals_dS(seqlen_kv);
|
||||
@@ -394,7 +387,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
}
|
||||
else
|
||||
{
|
||||
// D[sq] = dO[sq] . O[sq]
|
||||
// D[sq] = dO[sq] row(.) O[sq] (per-row dot product, accumulated over k in [0, hdim_v))
|
||||
GemmAccDataType D_acc = 0.f;
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
@@ -421,8 +414,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
|
||||
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
// Step 5: dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
// (computed fresh per sq row, no accumulation needed)
|
||||
// ------------------------------------------------------------------
|
||||
for(int k = 0; k < hdim_qk; k++)
|
||||
@@ -451,8 +443,7 @@ struct reference_no_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
|
||||
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// Step 6: dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// ------------------------------------------------------------------
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
@@ -668,7 +659,6 @@ struct reference_group_hstu_attention_bwd
|
||||
{
|
||||
// ------------------------------------------------------------------
|
||||
// Step 1: Recompute S[sq,:] and P[sq,:] (forward pass recomputation)
|
||||
// S[sq,sk] = alpha * Q[sq] . K[sk] (masked-in)
|
||||
// S = alpha * Q @ K (A=Q[sq,hdim_qk], B=K[sk,hdim_qk])
|
||||
// P[sq,sk] = silu(S)*scale_p or softmax_row(S)
|
||||
// ------------------------------------------------------------------
|
||||
@@ -730,8 +720,7 @@ struct reference_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 2: Accumulate dV[sk,k] += P[sq,sk] * dO[sq,k] (over sq)
|
||||
// dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
// Step 2: dV = P^T @ dO^T (A=P^T[sk,sq], B=dO^T[hdim_v,sq])
|
||||
// ------------------------------------------------------------------
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
@@ -746,8 +735,7 @@ struct reference_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 3: Compute dP[sq,sk] = dO[sq,:] . V[sk,:]
|
||||
// dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
// Step 3: dP = dO @ V (A=dO[sq,hdim_v], B=V[sk,hdim_v])
|
||||
// ------------------------------------------------------------------
|
||||
std::vector<CompDataType> locals_dP(seqlen_kv);
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
@@ -773,7 +761,7 @@ struct reference_group_hstu_attention_bwd
|
||||
// = 0 (masked-out)
|
||||
//
|
||||
// kUseSoftmax=true (Softmax):
|
||||
// D[sq] = dO[sq] . O[sq] (uses forward output O directly)
|
||||
// D[sq] = dO[sq] row(.) O[sq] (row(.) = per-row dot product over hdim_v; uses forward output O directly)
|
||||
// dS[sq,sk] = P[sq,sk] * (dP[sq,sk] - D[sq])
|
||||
// ------------------------------------------------------------------
|
||||
std::vector<CompDataType> locals_dS(seqlen_kv);
|
||||
@@ -791,7 +779,7 @@ struct reference_group_hstu_attention_bwd
|
||||
}
|
||||
else
|
||||
{
|
||||
// D[sq] = dO[sq] . O[sq]
|
||||
// D[sq] = dO[sq] row(.) O[sq] (per-row dot product, accumulated over k in [0, hdim_v))
|
||||
GemmAccDataType D_acc = 0.f;
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
@@ -808,8 +796,7 @@ struct reference_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 5: Compute dQ[sq,k] = alpha * sum_sk dS[sq,sk] * K[sk,k]
|
||||
// dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
// Step 5: dQ = alpha * dS @ K^T (A=dS[sq,sk], B=K^T[hdim_qk,sk])
|
||||
// (computed fresh per sq row, no accumulation needed)
|
||||
// ------------------------------------------------------------------
|
||||
for(int k = 0; k < hdim_qk; k++)
|
||||
@@ -829,8 +816,7 @@ struct reference_group_hstu_attention_bwd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Step 6: Accumulate dK[sk,k] += alpha * dS[sq,sk] * Q[sq,k] (over sq)
|
||||
// dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// Step 6: dK = alpha * dS^T @ Q^T (A=dS^T[sk,sq], B=Q^T[hdim_qk,sq])
|
||||
// ------------------------------------------------------------------
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user