CK-UA: split fmha_alu0 into rowmax/shift lambdas (default-off pipelining hook)

Refactors fmha_alu0 into fmha_alu0_rowmax (max3 reduce + cross-lane swap +
packed shift-coefficient precompute) and fmha_alu0_shift (the v_pk_fma sweep),
with a combined fmha_alu0 for the peeled/tail iterations. Bit-identical
(verified: split-on and split-off produce byte-for-byte identical standalone
output and PASS the fp8 accuracy check).
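
A minimal host-side sketch of the split, assuming scalar std::fma in place of
v_pk_fma_f32 and a plain max in place of the max3 tree and cross-lane swap.
ShiftCoeffs, alu0_rowmax, alu0_shift, alu0 and bit_identical are illustrative
names, not the kernel's lambdas. It shows why the split is a pure reordering:
the shift half reads only the stored coefficients, so other work can run
between the two halves without changing the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the coefficients carried between the two halves
// (the diff keeps them in ua_shift_scale_pair / ua_shift_addend_pair).
struct ShiftCoeffs {
    float scale;   // scale_s
    float addend;  // -scale_s * rowmax
};

// Rowmax half: reduce the scores to one max and precompute the FMA operands.
inline ShiftCoeffs alu0_rowmax(const std::vector<float>& scores, float scale_s) {
    assert(!scores.empty());
    const float row_max = *std::max_element(scores.begin(), scores.end());
    return ShiftCoeffs{scale_s, -scale_s * row_max};
}

// Shift half: delta[i] = fma(S[i], scale, addend) = scale_s * (S[i] - max).
inline std::vector<float> alu0_shift(const std::vector<float>& scores, ShiftCoeffs c) {
    std::vector<float> delta(scores.size());
    for (std::size_t i = 0; i < scores.size(); ++i) {
        delta[i] = std::fma(scores[i], c.scale, c.addend);
    }
    return delta;
}

// Combined form: the two halves run back to back.
inline std::vector<float> alu0(const std::vector<float>& scores, float scale_s) {
    return alu0_shift(scores, alu0_rowmax(scores, scale_s));
}

// Byte-for-byte comparison of two result buffers.
inline bool bit_identical(const std::vector<float>& a, const std::vector<float>& b) {
    return a.size() == b.size() &&
           (a.empty() || std::memcmp(a.data(), b.data(), a.size() * sizeof(float)) == 0);
}
```

Calling alu0_rowmax, doing unrelated work, then calling alu0_shift should give
a buffer for which bit_identical(..., alu0(scores, scale_s)) holds.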

UA_FA4_SPLIT_ROWMAX (default 0) issues the rowmax ahead of the PV gemm in the
steady-state ping-pong so the MFMA cluster can cover the reduce->shift-addend
chain. MEASURED NEUTRAL (1782 TF/s either way) on the canonical fp8 prefill
shape: the post-RA scheduler already groups instructions per the core-loop
sched_group_barrier hints, so source-order reordering does not move the emitted
schedule. Kept as a structural hook; the shift stall is the QK-MFMA-result wait,
not the addend.

Context: ATT phase profiling shows softmax compute is largely hidden under the
ping-pong (matrix is only ~11% of the wave); the wall-time gate is the
barrier+memwait rendezvous (~30% of the wave) on K/V DRAM latency. The exp2
Schraudolph approximation (UA_FA4_EXP2_APPROX=1) is an 11% regression here and
stays off. See ua-test-scripts/kv128_vgpr_findings.md for the full breakdown.
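
For reference, a host-side sketch of the Schraudolph fold described in the
fmha_alu0_rowmax comment in this diff (bits = 2^23*scale_s*(S-max) + bias,
then a float-to-u32 convert). The constants are assumptions: the scale is
taken to be 2^23 and the bias the plain exponent bias 127*2^23 with no
error-centering term, so the values of UA_EXP2_SCHRAUDOLPH_SCALE and
UA_EXP2_SCHRAUDOLPH_BIAS in the kernel may differ. FoldedCoeffs, fold_coeffs
and exp2_from_folded are illustrative names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Assumed constants (not taken from the kernel headers).
constexpr float kSchraudolphScale = 8388608.0f;           // 2^23
constexpr float kSchraudolphBias  = 127.0f * 8388608.0f;  // 127 << 23

struct FoldedCoeffs {
    float scale;   // scale_s * 2^23
    float addend;  // -scale_s * 2^23 * max + bias
};

// Fold the softmax scale, the rowmax and the exponent bias into one FMA.
inline FoldedCoeffs fold_coeffs(float scale_s, float row_max) {
    const float eff_scale = scale_s * kSchraudolphScale;
    return FoldedCoeffs{eff_scale, -eff_scale * row_max + kSchraudolphBias};
}

// Approximate exp2(scale_s * (s - max)): one FMA builds the float bit
// pattern as a number, a truncating convert turns it into the integer bits.
inline float exp2_from_folded(float s, FoldedCoeffs c) {
    const float bits_f = std::fma(s, c.scale, c.addend);
    assert(bits_f >= 0.0f && bits_f < 4294967296.0f);  // must fit in u32
    const std::uint32_t bits = static_cast<std::uint32_t>(bits_f);
    float out;
    std::memcpy(&out, &bits, sizeof out);  // reinterpret the bits as f32
    return out;
}
```

Under these assumptions integer exponents come out exact, while fractional
exponents are linearly interpolated in the mantissa, which is where the
approximation error comes from.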

Co-authored-by: Cursor <cursoragent@cursor.com>
juuso-oskari
2026-06-13 11:59:21 +00:00
parent 29e0f75e19
commit d912139ca9

@@ -129,6 +129,34 @@
#define UA_FA4_PACKED_ALU1_RESCALE 1
#endif
// UA_FA4_SPLIT_ROWMAX: pipeline the rowmax ahead of the score shift.
//
// fmha_alu0 is logically rowmax (max3 tree + cross-lane swap -> per-thread max)
// followed by the score shift (v_pk_fma: scale_s*(S - max)). The 32 packed shifts
// all depend on the single FMA addend (-scale_s*max), which is the tail of the
// rowmax reduction chain. With the monolithic alu0 emitted right after the PV
// gemm, the shifts have no independent work to cover that addend latency and
// stall on it (ATT: the shift v_pk_fma is the #1 softmax stall source, ~56% of
// softmax stall cycles on the canonical fp8 prefill shape).
//
// When set, the steady-state ping-pong (cl_calc) issues fmha_alu0_rowmax for the
// next sp tile BEFORE the PV gemm and fmha_alu0_shift AFTER it, so the PV MFMA
// cluster hides the rowmax->addend chain and the shift issues with its addend
// already resolved. Mirrors the reference softmax, which computes the rowmax
// inside gemm_QK. Bit-identical (pure reordering of independent work).
//
// MEASURED NEUTRAL (1782 TF/s either way) on the canonical fp8 prefill shape:
// the post-RA scheduler already groups instructions per the core-loop
// sched_group_barrier hints, so source-level reordering does not change the
// emitted schedule. The shift's stall is the wait on the QK MFMA result
// (sp_compute), not the rowmax addend. Kept (default off) as a structural hook;
// moving the rowmax earlier in the *schedule* requires editing the
// sched_group_barrier hints, not the call order. fmha_alu0 stays monolithic in
// the hot path when off (the original known-good schedule).
#ifndef UA_FA4_SPLIT_ROWMAX
#define UA_FA4_SPLIT_ROWMAX 0
#endif
// CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2)
// 0 (default): always-rescale online softmax — the o_acc/l accumulators are
// renormalised to the true running max `m` every KV tile (the expensive
@@ -1656,6 +1684,12 @@ struct UnifiedAttentionPipeline
decltype(m) m_old;
SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
// Packed score-shift coefficients (scale_s, -scale_s*max), computed in
// fmha_alu0_rowmax and consumed by fmha_alu0_shift when the two are
// split across the PV gemm (UA_FA4_SPLIT_ROWMAX). Both lanes hold the
// same scalar (the addend is per-thread uniform).
fp32x2_t ua_shift_scale_pair{};
fp32x2_t ua_shift_addend_pair{};
/// TODO: remove the sp_delta and use sp_compute directly
// sp_delta follows sp: single slot for the kv128 tile, double otherwise.
struct sp_delta_holder_t
@@ -1675,7 +1709,11 @@ struct UnifiedAttentionPipeline
constexpr bool kUseExp2Approx =
(UA_FA4_EXP2_APPROX != 0) && (UA_FA4_PACKED_SHIFT != 0) && !FmhaMask::IsMasking;
auto fmha_alu0 = [&](auto sp_reg_idx) {
// Rowmax half of fmha_alu0: reduce sp_compute -> per-thread running max
// `m`, then precompute the (uniform) packed shift coefficients. Split out
// so the steady-state loop can hoist it ahead of the PV gemm and hide the
// reduce->addend latency under the MFMA cluster (UA_FA4_SPLIT_ROWMAX).
auto fmha_alu0_rowmax = [&](auto sp_reg_idx) {
m_old = m; // m{j-1}
static_assert(m.thread_buf_.size() == 1,
"assuming that each thread holds 1 rowmax value");
@@ -1732,36 +1770,53 @@ struct UnifiedAttentionPipeline
#endif
#if UA_FA4_PACKED_SHIFT
// Packed score shift: each thread holds exactly one rowmax, so the FMA
// addend (-scale_s * rowmax) is uniform across the thread's score
// elements. Broadcast scale_s and the addend into both packed lanes and
// emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs.
// Bit-identical to the scalar fma_impl_vsv sweep below. The
// one-rowmax-per-thread invariant is asserted on `m` above.
static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0,
"packed shift needs an even score-register count");
// Precompute the (uniform) packed shift coefficients so fmha_alu0_shift
// can apply them later, after the PV gemm has hidden this reduce chain.
// Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max + bias)
// = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in fmha_alu1.
// Exact path: sp_delta = scale_s*(S-max).
{
// Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max +
// bias) = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in
// fmha_alu1. Exact path: sp_delta = scale_s*(S-max).
const float eff_scale =
kUseExp2Approx ? (scale_s * UA_EXP2_SCHRAUDOLPH_SCALE) : scale_s;
const float addend =
kUseExp2Approx
? (-eff_scale * m_shift.thread_buf_[0] + UA_EXP2_SCHRAUDOLPH_BIAS)
: (-scale_s * m_shift.thread_buf_[0]);
const fp32x2_t scale_pair{eff_scale, eff_scale};
const fp32x2_t addend_pair{addend, addend};
static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) {
fp32x2_t in;
in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx];
in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1];
auto out = detail::pk_fma_f32(in, scale_pair, addend_pair);
sp_delta(sp_reg_idx).thread_buf_[idx] = out.x;
sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y;
});
ua_shift_scale_pair = fp32x2_t{eff_scale, eff_scale};
ua_shift_addend_pair = fp32x2_t{addend, addend};
}
#else
(void)m_shift;
#endif
/// TODO: move some fmha_alu1() code here if necessary
};
// Shift half of fmha_alu0: sp_delta = scale_s*(S - max) using the packed
// coefficients computed in fmha_alu0_rowmax. Reads sp_compute (already
// retired) and the resolved addend, so it issues without stalling on the
// rowmax reduction chain when hoisted apart from it (UA_FA4_SPLIT_ROWMAX).
auto fmha_alu0_shift = [&](auto sp_reg_idx) {
#if UA_FA4_PACKED_SHIFT
// Each thread holds exactly one rowmax, so the FMA addend
// (-scale_s * rowmax) is uniform across the thread's score elements.
// Emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs.
// Bit-identical to the scalar fma_impl_vsv sweep below.
static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0,
"packed shift needs an even score-register count");
static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) {
fp32x2_t in;
in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx];
in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1];
auto out = detail::pk_fma_f32(in, ua_shift_scale_pair, ua_shift_addend_pair);
sp_delta(sp_reg_idx).thread_buf_[idx] = out.x;
sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y;
});
#else
#if CONDITIONAL_RESCALE
auto& m_shift = kCondRescale ? m_commit : m;
#else
auto& m_shift = m;
#endif
constexpr auto p_spans =
std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
@@ -1772,7 +1827,13 @@ struct UnifiedAttentionPipeline
});
});
#endif
/// TODO: move some fmha_alu1() code here if necessary
};
// Combined alu0 (rowmax + shift), used by the peeled/tail iterations where
// there is no PV gemm to hoist the rowmax under.
auto fmha_alu0 = [&](auto sp_reg_idx) {
fmha_alu0_rowmax(sp_reg_idx);
fmha_alu0_shift(sp_reg_idx);
};
auto fmha_alu1 = [&](auto sp_reg_idx) {
@@ -2246,6 +2307,20 @@ struct UnifiedAttentionPipeline
}
else
{
#if UA_FA4_SPLIT_ROWMAX
// Hoist the next tile's rowmax ahead of the PV gemm so the MFMA
// cluster hides the reduce->shift-addend chain; the shift then
// issues after the gemm with its addend already resolved.
fmha_alu0_rowmax(number<1>{} - sp_reg_idx);
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
sequence<kBlockM, k1_loops * kPageBlockSize>{}),
get_slice_tile(kv_tile.v_tile,
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
sequence<kHeadDimPadded, k1_loops * kPageBlockSize>{}));
fmha_alu0_shift(number<1>{} - sp_reg_idx);
#else
gemm_1(o_acc,
get_slice_tile(sp(sp_reg_idx).p,
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
@@ -2254,6 +2329,7 @@ struct UnifiedAttentionPipeline
sequence<0, (k1_loops - 1) * kPageBlockSize>{},
sequence<kHeadDimPadded, k1_loops * kPageBlockSize>{}));
fmha_alu0(number<1>{} - sp_reg_idx);
#endif
}
#if UA_DYNAMIC_SETPRIO
__builtin_amdgcn_s_setprio(0);