diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 34dc863ed1..d86965d4bc 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -129,6 +129,34 @@ #define UA_FA4_PACKED_ALU1_RESCALE 1 #endif +// UA_FA4_SPLIT_ROWMAX: pipeline the rowmax ahead of the score shift. +// +// fmha_alu0 is logically rowmax(max3 tree + cross-lane swap -> per-thread max) +// followed by the score shift (v_pk_fma: scale_s*(S - max)). The 32 packed shifts +// all depend on the single FMA addend (-scale_s*max), which is the tail of the +// rowmax reduction chain. Hypothesis (refuted below): with the monolithic alu0 +// emitted right after the PV gemm, the shifts have no independent work to cover +// that addend latency and stall on it (ATT: the shift v_pk_fma is the #1 softmax +// stall source, ~56% of softmax stall cycles on the canonical fp8 prefill shape). +// +// When set, the steady-state ping-pong (cl_calc) issues fmha_alu0_rowmax for the +// next sp tile BEFORE the PV gemm and fmha_alu0_shift AFTER it; the intent is for +// the PV MFMA cluster to hide the rowmax->addend chain so the shift issues with +// its addend already resolved. Mirrors the reference softmax, which computes the +// rowmax inside gemm_QK. Bit-identical (pure reordering of independent work). +// +// MEASURED NEUTRAL (1782 TF/s either way) on the canonical fp8 prefill shape: +// the post-RA scheduler already groups instructions per the core-loop +// sched_group_barrier hints, so source-level reordering does not change the +// emitted schedule. The shift's stall is the wait on the QK MFMA result +// (sp_compute), not the rowmax addend. Kept (default off) as a structural hook; +// moving the rowmax earlier in the *schedule* requires editing the +// sched_group_barrier hints, not the call order. 
fmha_alu0 stays monolithic in +// the hot path when off (the original known-good schedule). +#ifndef UA_FA4_SPLIT_ROWMAX +#define UA_FA4_SPLIT_ROWMAX 0 +#endif + // CONDITIONAL_RESCALE (PLAN_conditional_rescale Part 2) // 0 (default): always-rescale online softmax — the o_acc/l accumulators are // renormalised to the true running max `m` every KV tile (the expensive @@ -1656,6 +1684,12 @@ struct UnifiedAttentionPipeline decltype(m) m_old; SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + // Packed score-shift coefficients (scale_s, -scale_s*max), computed in + // fmha_alu0_rowmax and consumed by fmha_alu0_shift when the two are + // split across the PV gemm (UA_FA4_SPLIT_ROWMAX). Both lanes hold the + // same scalar (the addend is per-thread uniform). + fp32x2_t ua_shift_scale_pair{}; + fp32x2_t ua_shift_addend_pair{}; /// TODO: remove the sp_delta and use sp_compute directly // sp_delta follows sp: single slot for the kv128 tile, double otherwise. struct sp_delta_holder_t @@ -1675,7 +1709,11 @@ struct UnifiedAttentionPipeline constexpr bool kUseExp2Approx = (UA_FA4_EXP2_APPROX != 0) && (UA_FA4_PACKED_SHIFT != 0) && !FmhaMask::IsMasking; - auto fmha_alu0 = [&](auto sp_reg_idx) { + // Rowmax half of fmha_alu0: reduce sp_compute -> per-thread running max + // `m`, then precompute the (uniform) packed shift coefficients. Split out + // so the steady-state loop can hoist it ahead of the PV gemm and hide the + // reduce->addend latency under the MFMA cluster (UA_FA4_SPLIT_ROWMAX). + auto fmha_alu0_rowmax = [&](auto sp_reg_idx) { m_old = m; // m{j-1} static_assert(m.thread_buf_.size() == 1, "assuming that each thread holds 1 rowmax value"); @@ -1732,36 +1770,53 @@ struct UnifiedAttentionPipeline #endif #if UA_FA4_PACKED_SHIFT - // Packed score shift: each thread holds exactly one rowmax, so the FMA - // addend (-scale_s * rowmax) is uniform across the thread's score - // elements. 
Broadcast scale_s and the addend into both packed lanes and - // emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs. - // Bit-identical to the scalar fma_impl_vsv sweep below. The - // one-rowmax-per-thread invariant is asserted on `m` above. - static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0, - "packed shift needs an even score-register count"); + // Precompute the (uniform) packed shift coefficients so fmha_alu0_shift + // can apply them later, after the PV gemm has hidden this reduce chain. + // Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max + bias) + // = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in fmha_alu1. + // Exact path: sp_delta = scale_s*(S-max). { - // Schraudolph fold: bits = S*(scale_s*2^23) + (-scale_s*2^23*max + - // bias) = 2^23*scale_s*(S-max) + bias, finished by v_cvt_u32_f32 in - // fmha_alu1. Exact path: sp_delta = scale_s*(S-max). const float eff_scale = kUseExp2Approx ? (scale_s * UA_EXP2_SCHRAUDOLPH_SCALE) : scale_s; const float addend = kUseExp2Approx ? (-eff_scale * m_shift.thread_buf_[0] + UA_EXP2_SCHRAUDOLPH_BIAS) : (-scale_s * m_shift.thread_buf_[0]); - const fp32x2_t scale_pair{eff_scale, eff_scale}; - const fp32x2_t addend_pair{addend, addend}; - static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) { - fp32x2_t in; - in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx]; - in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]; - auto out = detail::pk_fma_f32(in, scale_pair, addend_pair); - sp_delta(sp_reg_idx).thread_buf_[idx] = out.x; - sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y; - }); + ua_shift_scale_pair = fp32x2_t{eff_scale, eff_scale}; + ua_shift_addend_pair = fp32x2_t{addend, addend}; } #else + (void)m_shift; +#endif + /// TODO: move some fmha_alu1() code here if necessary + }; + + // Shift half of fmha_alu0: sp_delta = scale_s*(S - max) using the packed + // coefficients computed in fmha_alu0_rowmax. 
Reads sp_compute (already + // retired) and the resolved addend, so it issues without stalling on the + // rowmax reduction chain when hoisted apart from it (UA_FA4_SPLIT_ROWMAX). + auto fmha_alu0_shift = [&](auto sp_reg_idx) { +#if UA_FA4_PACKED_SHIFT + // Each thread holds exactly one rowmax, so the FMA addend + // (-scale_s * rowmax) is uniform across the thread's score elements. + // Emit v_pk_fma_f32 (2 f32/instr) over sp_compute.thread_buf_ pairs. + // Bit-identical to the scalar fma_impl_vsv sweep below. + static_assert(sp(sp_reg_idx).sp_compute.thread_buf_.size() % 2 == 0, + "packed shift needs an even score-register count"); + static_for<0, sp(sp_reg_idx).sp_compute.thread_buf_.size(), 2>{}([&](auto idx) { + fp32x2_t in; + in.x = sp(sp_reg_idx).sp_compute.thread_buf_[idx]; + in.y = sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]; + auto out = detail::pk_fma_f32(in, ua_shift_scale_pair, ua_shift_addend_pair); + sp_delta(sp_reg_idx).thread_buf_[idx] = out.x; + sp_delta(sp_reg_idx).thread_buf_[idx + 1] = out.y; + }); +#else +#if CONDITIONAL_RESCALE + auto& m_shift = kCondRescale ? m_commit : m; +#else + auto& m_shift = m; +#endif constexpr auto p_spans = std::decay_t::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { @@ -1772,7 +1827,13 @@ struct UnifiedAttentionPipeline }); }); #endif - /// TODO: move some fmha_alu1() code here if necessary + }; + + // Combined alu0 (rowmax + shift), used by the peeled/tail iterations where + // there is no PV gemm to hoist the rowmax under. + auto fmha_alu0 = [&](auto sp_reg_idx) { + fmha_alu0_rowmax(sp_reg_idx); + fmha_alu0_shift(sp_reg_idx); }; auto fmha_alu1 = [&](auto sp_reg_idx) { @@ -2246,6 +2307,20 @@ struct UnifiedAttentionPipeline } else { +#if UA_FA4_SPLIT_ROWMAX + // Hoist the next tile's rowmax ahead of the PV gemm so the MFMA + // cluster hides the reduce->shift-addend chain; the shift then + // issues after the gemm with its addend already resolved. 
+ fmha_alu0_rowmax(number<1>{} - sp_reg_idx); + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{})); + fmha_alu0_shift(number<1>{} - sp_reg_idx); +#else gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, sequence<0, (k1_loops - 1) * kPageBlockSize>{}, @@ -2254,6 +2329,7 @@ struct UnifiedAttentionPipeline sequence<0, (k1_loops - 1) * kPageBlockSize>{}, sequence{})); fmha_alu0(number<1>{} - sp_reg_idx); +#endif } #if UA_DYNAMIC_SETPRIO __builtin_amdgcn_s_setprio(0);