mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: split fmha_alu0 into rowmax/shift lambdas (default-off pipelining hook)
Refactors fmha_alu0 into fmha_alu0_rowmax (max3 reduce + cross-lane swap + packed shift-coefficient precompute) and fmha_alu0_shift (the v_pk_fma sweep), with a combined fmha_alu0 for the peeled/tail iterations. Bit-identical (verified: split-on and split-off produce byte-for-byte identical standalone output and PASS the fp8 accuracy check). UA_FA4_SPLIT_ROWMAX (default 0) issues the rowmax ahead of the PV gemm in the steady-state ping-pong so the MFMA cluster can cover the reduce->shift-addend chain. MEASURED NEUTRAL (1782 TF/s either way) on the canonical fp8 prefill shape: the post-RA scheduler already groups instructions per the core-loop sched_group_barrier hints, so source-order reordering does not move the emitted schedule. Kept as a structural hook; the shift stall is the QK-MFMA-result wait, not the addend. Context: ATT phase profiling shows softmax compute is largely hidden under the ping-pong (matrix is only ~11% of the wave); the wall-time gate is the barrier+memwait rendezvous (~30% of the wave) on K/V DRAM latency. exp2 Schraudolph approx (UA_FA4_EXP2_APPROX=1) is -11% here and stays off. See ua-test-scripts/kv128_vgpr_findings.md for the full breakdown. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -129,6 +129,34 @@
|
||||
#define UA_FA4_PACKED_ALU1_RESCALE 1
|
||||
#endif
|
||||
|
||||
// UA_FA4_SPLIT_ROWMAX: pipeline the rowmax ahead of the score shift.
|
||||
//
|
||||
// fmha_alu0 is logically a rowmax (max3 tree + cross-lane swap -> per-thread max)
|
||||
// followed by the score shift (v_pk_fma: scale_s*(S - max)). The 32 packed shifts
|
||||
// all depend on the single FMA addend (-scale_s*max), which is the tail of the
|
||||
// rowmax reduction chain. Hypothesis: with the monolithic alu0 emitted right after the PV
|
||||
// gemm, the shifts have no independent work to cover that addend latency and
|
||||
// stall on it (ATT: the shift v_pk_fma is the #1 softmax stall source, ~56% of
|
||||
// softmax stall cycles on the canonical fp8 prefill shape).
|
||||
//
|
||||
// When set, the steady-state ping-pong (cl_calc) issues fmha_alu0_rowmax for the
|
||||
// next sp tile BEFORE the PV gemm and fmha_alu0_shift AFTER it, so the PV MFMA
|
||||
// cluster hides the rowmax->addend chain and the shift issues with its addend
|
||||
// already resolved. Mirrors the reference softmax, which computes the rowmax
|
||||
// inside gemm_QK. Bit-identical (pure reordering of independent work).
|
||||
//
|
||||
// MEASURED NEUTRAL (1782 TF/s either way) on the canonical fp8 prefill shape:
|
||||
// the post-RA scheduler already groups instructions per the core-loop
|
||||
// sched_group_barrier hints, so source-level reordering does not change the
|
||||
// emitted schedule. The shift's stall is the wait on the QK MFMA result
|
||||
// (sp_compute), not the rowmax addend. Kept (default off) as a structural hook;
|
||||
// moving the rowmax earlier in the *schedule* requires editing the
|
||||
// sched_group_barrier hints, not the call order. fmha_alu0 stays monolithic in
|
||||
// the hot path when off (the original known-good schedule).
|
||||
#ifndef UA_FA4_SPLIT_ROWMAX
|
||||
#define UA_FA4_SPLIT_ROWMAX 0
|
||||
#endif
|
||||
|
||||
// CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2)
|
||||
// 0 (default): always-rescale online softmax — the o_acc/l accumulators are
|
||||
// renormalised to the true running max `m` every KV tile (the expensive
|
||||
@@ -1656,6 +1684,12 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
decltype(m) m_old;
|
||||
SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
|
||||
// Packed score-shift coefficients (scale_s, -scale_s*max), computed in
|
||||
// fmha_alu0_rowmax and consumed by fmha_alu0_shift when the two are
|
||||
// split across the PV gemm (UA_FA4_SPLIT_ROWMAX). Both lanes of each pair
|
||||
// hold the same scalar (scale and addend are per-thread uniform).
|
||||
fp32x2_t ua_shift_scale_pair{};
|
||||
fp32x2_t ua_shift_addend_pair{};
|
||||
/// TODO: remove the sp_delta and use sp_compute directly
|
||||
// sp_delta follows sp: single slot for the kv128 tile, double otherwise.
|
||||
struct sp_delta_holder_t
|
||||
@@ -1675,7 +1709,11 @@ struct UnifiedAttentionPipeline
|
||||
constexpr bool kUseExp2Approx =
|
||||
(UA_FA4_EXP2_APPROX != 0) && (UA_FA4_PACKED_SHIFT != 0) && !FmhaMask::IsMasking;
|
||||
|
||||
auto fmha_alu0 = [&](auto sp_reg_idx) {
|
||||
// Rowmax half of fmha_alu0: reduce sp_compute -> per-thread running max
|
||||
// `m`, then precompute the (uniform) packed shift coefficients. Split out
|
||||
// so the steady-state loop can hoist it ahead of the PV gemm and hide the
|
||||
// reduce->addend latency under the MFMA cluster (UA_FA4_SPLIT_ROWMAX).
|
||||
auto fmha_alu0_rowmax = [&](auto sp_reg_idx) {
|
||||
m_old = m; // m{j-1}
|
||||
static_assert(m.thread_buf_.size() == 1,
|
||||
"assuming that each thread holds 1 rowmax value");
|
||||
@@ -1732,36 +1770,53 @@ struct UnifiedAttentionPipeline
|
||||
#endif
|
||||
|
||||
#if UA_FA4_PACKED_SHIFT
|
||||
// Packed score shift: each thread holds exactly one rowmax, so the FMA
|
||||
// addend (-scale_s * rowmax) is uniform across the thread's score
|
||||
// elements. Broadcast scale_s and the addend into both packed lanes and
|
||||
// emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs.
|
||||
// Bit-identical to the scalar fma_impl_vsv sweep below. The
|
||||
// one-rowmax-per-thread invariant is asserted on `m` above.
|
||||
static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0,
|
||||
"packed shift needs an even score-register count");
|
||||
// Precompute the (uniform) packed shift coefficients so fmha_alu0_shift
|
||||
// can apply them later, after the PV gemm has hidden this reduce chain.
|
||||
// Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max + bias)
|
||||
// = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in fmha_alu1.
|
||||
// Exact path: sp_delta = scale_s*(S-max).
|
||||
{
|
||||
// Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max +
|
||||
// bias) = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in
|
||||
// fmha_alu1. Exact path: sp_delta = scale_s*(S-max).
|
||||
const float eff_scale =
|
||||
kUseExp2Approx ? (scale_s * UA_EXP2_SCHRAUDOLPH_SCALE) : scale_s;
|
||||
const float addend =
|
||||
kUseExp2Approx
|
||||
? (-eff_scale * m_shift.thread_buf_[0] + UA_EXP2_SCHRAUDOLPH_BIAS)
|
||||
: (-scale_s * m_shift.thread_buf_[0]);
|
||||
const fp32x2_t scale_pair{eff_scale, eff_scale};
|
||||
const fp32x2_t addend_pair{addend, addend};
|
||||
static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
fp32x2_t in;
|
||||
in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx];
|
||||
in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1];
|
||||
auto out = detail::pk_fma_f32(in, scale_pair, addend_pair);
|
||||
sp_delta(sp_reg_idx).thread_buf_[idx] = out.x;
|
||||
sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y;
|
||||
});
|
||||
ua_shift_scale_pair = fp32x2_t{eff_scale, eff_scale};
|
||||
ua_shift_addend_pair = fp32x2_t{addend, addend};
|
||||
}
|
||||
#else
|
||||
(void)m_shift;
|
||||
#endif
|
||||
/// TODO: move some fmha_alu1() code here if necessary
|
||||
};
|
||||
|
||||
// Shift half of fmha_alu0: sp_delta = scale_s*(S - max) using the packed
|
||||
// coefficients computed in fmha_alu0_rowmax. Reads sp_compute (already
|
||||
// retired) and the resolved addend, so it issues without stalling on the
|
||||
// rowmax reduction chain when hoisted apart from it (UA_FA4_SPLIT_ROWMAX).
|
||||
auto fmha_alu0_shift = [&](auto sp_reg_idx) {
|
||||
#if UA_FA4_PACKED_SHIFT
|
||||
// Each thread holds exactly one rowmax, so the FMA addend
|
||||
// (-scale_s * rowmax) is uniform across the thread's score elements.
|
||||
// Emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs.
|
||||
// Bit-identical to the scalar fma_impl_vsv sweep below.
|
||||
static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0,
|
||||
"packed shift needs an even score-register count");
|
||||
static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
fp32x2_t in;
|
||||
in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx];
|
||||
in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1];
|
||||
auto out = detail::pk_fma_f32(in, ua_shift_scale_pair, ua_shift_addend_pair);
|
||||
sp_delta(sp_reg_idx).thread_buf_[idx] = out.x;
|
||||
sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y;
|
||||
});
|
||||
#else
|
||||
#if CONDITIONAL_RESCALE
|
||||
auto& m_shift = kCondRescale ? m_commit : m;
|
||||
#else
|
||||
auto& m_shift = m;
|
||||
#endif
|
||||
constexpr auto p_spans =
|
||||
std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -1772,7 +1827,13 @@ struct UnifiedAttentionPipeline
|
||||
});
|
||||
});
|
||||
#endif
|
||||
/// TODO: move some fmha_alu1() code here if necessary
|
||||
};
|
||||
|
||||
// Combined alu0 (rowmax + shift), used by the peeled/tail iterations where
|
||||
// there is no PV gemm to hoist the rowmax under.
// Also the steady-state path when UA_FA4_SPLIT_ROWMAX is 0, where it is called
// right after gemm_1. Call order matters: with UA_FA4_PACKED_SHIFT,
// fmha_alu0_shift consumes the ua_shift_scale_pair / ua_shift_addend_pair that
// fmha_alu0_rowmax has just written for the same sp_reg_idx.
|
||||
auto fmha_alu0 = [&](auto sp_reg_idx) {
|
||||
fmha_alu0_rowmax(sp_reg_idx);
|
||||
fmha_alu0_shift(sp_reg_idx);
|
||||
};
|
||||
|
||||
auto fmha_alu1 = [&](auto sp_reg_idx) {
|
||||
@@ -2246,6 +2307,20 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
else
|
||||
{
|
||||
#if UA_FA4_SPLIT_ROWMAX
|
||||
// Hoist the next tile's rowmax ahead of the PV gemm so the MFMA
|
||||
// cluster hides the reduce->shift-addend chain; the shift then
|
||||
// issues after the gemm with its addend already resolved.
|
||||
fmha_alu0_rowmax(number<1>{} - sp_reg_idx);
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(sp(sp_reg_idx).p,
|
||||
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
|
||||
sequence<kBlockM, k1_loops * kPageBlockSize>{}),
|
||||
get_slice_tile(kv_tile.v_tile,
|
||||
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
|
||||
sequence<kHeadDimPadded, k1_loops * kPageBlockSize>{}));
|
||||
fmha_alu0_shift(number<1>{} - sp_reg_idx);
|
||||
#else
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(sp(sp_reg_idx).p,
|
||||
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
|
||||
@@ -2254,6 +2329,7 @@ struct UnifiedAttentionPipeline
|
||||
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
|
||||
sequence<kHeadDimPadded, k1_loops * kPageBlockSize>{}));
|
||||
fmha_alu0(number<1>{} - sp_reg_idx);
|
||||
#endif
|
||||
}
|
||||
#if UA_DYNAMIC_SETPRIO
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
|
||||
Reference in New Issue
Block a user