diff --git a/example/ck_tile/50_sparse_attn/README.md b/example/ck_tile/50_sparse_attn/README.md index c7191c8e82..9fdad906de 100644 --- a/example/ck_tile/50_sparse_attn/README.md +++ b/example/ck_tile/50_sparse_attn/README.md @@ -14,10 +14,23 @@ Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/ - **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53)) - **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338)) - **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355)) -- **pv_threshold per-Q-tile skip in attn kernel** — pure perf, ~5–15% on the dominant attention slice ([spas_sage_attn/core.py:L265](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L265)) - **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345)) - **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371)) +## PV-skip modes + +`pv_threshold` per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime via `-pv_mode={none|warp|block}`: + +- **`none`** — 
skip disabled; baseline matching the no-PV-skip codegen instance. +- **`warp`** (per-wavefront) — each wavefront votes locally via `__shfl_xor` butterfly AND; SGPR-resident flag. CK-tile-specific variant, not in upstream. +- **`block`** (per-block) — block-wide consensus vote via LDS broadcast; aligned with upstream sm80 ([`qk_int_sv_f16_cuda_sm80.cuh:L334`](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cuh#L334)). V loads stay unconditional in all modes — the guard wraps the PV MMA only, matching upstream and paper Algorithm 1. + +![PV-skip mode comparison](docs/pv_skip_mode_comparison.png) + +*MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. All three modes dispatch to the `kM0=64 padK=0` tile bucket at this shape.* + +On the canonical recipe shape, `none > warp > block` at every measured sparsity, with no crossover. The per-block guard adds +33..+35 VGPR (6..9 spills) on this tile configuration, depressing occupancy. `warp` is +0..+4 VGPR. The default is `-pv_mode=warp` (preserves R25 A1 behaviour); switch to `none` for the no-skip baseline or `block` to exercise the upstream-aligned variant. A shape sweep is needed before recommending `block` as default — the `kM0=128` path has Δ ≈ 0 VGPR for per-block and is a candidate. + ## Performance At b=2 h=32 s=16384 fp16, sparge (vsa backend) reaches **1.78× FMHA throughput at topk=0.4** and **5.04× at topk=0.1**, and stays above 1.0× across the full topk range. @@ -37,6 +50,8 @@ ninja tile_example_sparge ./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001 ``` +Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire. + Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. 
## References diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py index 9489d3758f..e5182c3dc8 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py @@ -114,48 +114,55 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_trload}, fmha_trait_{F_idx}>; -// R25 V0: instantiate the Sparge pipeline with kEnablePVSkip = true (existing path) -// AND kEnablePVSkip = false (PV-skip AST removed at compile time, source-equivalent -// to the frozen VSA reference). Both kernels live in the same TU; the host dispatch -// in fmha_sparge_fwd_api.cpp picks one based on fmha_sparge_fwd_args::pv_skip_compile. -using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< - fmha_pipeline_problem_{F_idx}, - ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, - true>; +// R30: emit 3 pipeline / kernel instances per traits combo — kNone (PV-skip +// AST removed; source-equivalent to VSA), kPerWave (R25 A1 shipped path), +// kPerBlock (R30 added: block-wide AND vote gates gemm_1). The host dispatch +// in fmha_sparge_fwd_api.cpp picks one based on +// fmha_sparge_fwd_args::pv_mode_compile (0/1/2). +// R26 split-launch: fmha_fwd_create_kargs_and_grids(a) forwards the new +// fmha_sparge_fwd_args fields (pv_threshold_per_head_ptr, head_remap_ptr, +// nhead_in_launch) to MakeKargs. When head_remap_ptr is non-null the wrapper +// also shrinks grids.y to nhead_in_launch so each bucket fires its own kernel. 
+// Suffixes: +// _pvsf = PV-Skip OFF (kNone) +// _pvst = PV-Skip per-WAVE (kPerWave; preserved R25 A1 binary name) +// _pvsb = PV-Skip per-BLOCK (kPerBlock; R30 new) using fmha_pipeline_{F_idx}_pvsf = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< fmha_pipeline_problem_{F_idx}, ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, - false>; + ck_tile::PVSkipMode::kNone>; +using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerWave>; +using fmha_pipeline_{F_idx}_pvsb = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerBlock>; using fmha_epilogue_{F_idx} = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx}_pvst = - ck_tile::FmhaFwdSpargeKernel; using fmha_kernel_{F_idx}_pvsf = - ck_tile::FmhaFwdSpargeKernel; + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvst = + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvsb = + ck_tile::FmhaFwdSpargeKernel; using trait_{F_idx} = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include +// R30: 3 specializations per traits combo — int kPVMode values: +// 0 = kNone (pvsf binary) +// 1 = kPerWave (pvst binary; R25 A1 path) +// 2 = kPerBlock (pvsb binary; R30 new) template<> -float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}_pvst; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr 
ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}_pvsf; if(s.log_level_ > 0) @@ -167,7 +174,42 @@ float fmha_sparge_fwd_(const ck_tile::stream_config& s, fm }} template<> -void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsb; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvsb" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, 
fmha_sparge_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}_pvst; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); @@ -178,9 +220,9 @@ void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& }} template<> -void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}_pvsf; + using k_ = fmha_kernel_{F_idx}_pvsb; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; @@ -261,10 +303,13 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - if(a.pv_skip_compile) - return fmha_sparge_fwd_(s, a); - else - return fmha_sparge_fwd_(s, a); + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. 
+ switch(a.pv_mode_compile) {{ + case 0: return fmha_sparge_fwd_(s, a); + case 1: return fmha_sparge_fwd_(s, a); + case 2: return fmha_sparge_fwd_(s, a); + default: return fmha_sparge_fwd_(s, a); // legacy default = per-wave + }} }} """ @@ -302,11 +347,13 @@ FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - if(a.pv_skip_compile) - fmha_sparge_fwd_oneshot_(s, a); - else - fmha_sparge_fwd_oneshot_(s, a); - return; + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. + switch(a.pv_mode_compile) {{ + case 0: fmha_sparge_fwd_oneshot_(s, a); return; + case 1: fmha_sparge_fwd_oneshot_(s, a); return; + case 2: fmha_sparge_fwd_oneshot_(s, a); return; + default: fmha_sparge_fwd_oneshot_(s, a); return; + }} }} """ diff --git a/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png b/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png new file mode 100644 index 0000000000..b35c20a679 Binary files /dev/null and b/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png differ diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 8339b50389..071d0409b0 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -355,7 +355,15 @@ struct fmha_sparge_fwd_args ck_tile::index_t nhead_k; float scale_s; - float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold + float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold (scalar mode) + + // 
R26 split-launch: when non-null, per-head pv_threshold buffer (length nhead_q) + // is read on device instead of the scalar. Combined with head_remap_ptr the + // host can issue two launches (finite-threshold bucket + sentinel bucket) with + // different binaries. + const float* pv_threshold_per_head_ptr = nullptr; + const int* head_remap_ptr = nullptr; + int nhead_in_launch = 0; // 0 = identity (full nhead_q grid) ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -379,7 +387,18 @@ struct fmha_sparge_fwd_args // shipped pre-R25-V0 only had the true instance). Profiler can flip this to // false to measure the source-equivalent-to-VSA baseline (`if constexpr` // removes the entire PV-skip AST). + // + // R30: superseded by pv_mode_compile (int 0/1/2). Kept for source compat — + // the generated dispatch switches only on pv_mode_compile, so callers that + // still set this bool must also set pv_mode_compile = (pv_skip_compile ? 1 : 0). bool pv_skip_compile = true; + + // R30: 3-mode PV-skip select. + // 0 = kNone (no PV-skip; AST removed; equivalent to VSA baseline) + // 1 = kPerWave (R25 A1 shipped path; per-wavefront butterfly vote) + // 2 = kPerBlock (R30 added; block-wide AND vote through 1 LDS slot) + // Default 1 preserves R25 A1 behaviour for any caller that doesn't set it. + int pv_mode_compile = 1; }; template @@ -414,9 +433,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + // R26 split-launch extras + args.pv_threshold_per_head_ptr, + args.head_remap_ptr, + args.nhead_in_launch); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + // R26 split-launch: when head_remap is active, gridDim.y shrinks to bucket size. + const ck_tile::index_t grid_nhead = (args.head_remap_ptr != nullptr && args.nhead_in_launch > 0) + ?
args.nhead_in_launch + : args.nhead_q; + dim3 grids = FmhaKernel::GridSize(args.batch, grid_nhead, args.max_seqlen_q, args.hdim_v); return ck_tile::make_tuple(kargs, grids); } @@ -459,14 +486,15 @@ using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits; float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); -// R25 V0: kEnablePVSkip is now a template non-type param so the codegen can -// emit both true / false instantiations from the same source tree. The host -// dispatch (fmha_sparge_fwd_api.cpp) selects the right specialization based -// on fmha_sparge_fwd_args::pv_skip_compile at runtime. -template +// R25 V0 / R30: PV-skip mode is a template non-type param so codegen can emit +// all 3 instantiations from the same source tree. The host dispatch +// (fmha_sparge_fwd_api.cpp) selects the right specialization based on +// fmha_sparge_fwd_args::pv_mode_compile at runtime. +// 0 = kNone, 1 = kPerWave, 2 = kPerBlock (matches ck_tile::PVSkipMode). 
+template float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); -template +template void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args); void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits, diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index 06be6215bc..10a58ae05f 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -4,11 +4,14 @@ #include "sparge_blockmap_trek.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/host/device_memory.hpp" #include +#include #include #include #include +#include // ============================================================================ // Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) @@ -265,6 +268,21 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, [=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); }); } +// R26 split-launch: partition heads into two buckets by per-head pv_threshold +// (sentinel >= 1e29f vs finite), materialise device-side remap LUTs, then issue +// one fmha launch per non-empty bucket. Bucket selection happens entirely on +// the host; the kernel just reads head_remap_ptr[blockIdx.y] to recover the +// original head index. +// +// R30: the "finite" bucket binary is selected by attn_a.pv_mode_compile (1 = +// per-wave (R25 A1 default); 2 = per-block (R30)). The "sentinel" bucket is +// always kNone (mode 0) — sentinel heads requested PV-skip OFF so the per-head +// per-mode choice is degenerate. Per-head per-mode bucket (3+ buckets) is a +// future R31 extension; this commit keeps the 2-bucket scheme and routes the +// active mode through attn_true.pv_mode_compile. +// +// Backward compat: if pv_threshold_per_head_ptr is null, fall back to the +// original single-launch path using attn_a.pv_threshold scalar. 
float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, fmha_sparge_fwd_traits attn_t, @@ -277,11 +295,115 @@ float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, << ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; - return ck_tile::launch_kernel( - s, - [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, - [=](const ck_tile::stream_config& s_) { - sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); - }, - [=](const ck_tile::stream_config& s_) { fmha_sparge_fwd_oneshot(attn_t, attn_a, s_); }); + // Decide bucket plan. Pull per-head thresholds from device buffer when set, + // else broadcast the scalar across all heads to a single bucket. + const int nhead_q = attn_a.nhead_q; + std::vector false_heads; // pv_threshold >= 1e29f -> kNone binary (mode 0) + std::vector true_heads; // finite -> kPerWave or kPerBlock binary + false_heads.reserve(nhead_q); + true_heads.reserve(nhead_q); + + if(attn_a.pv_threshold_per_head_ptr != nullptr) + { + std::vector pv_host(nhead_q); + auto err = hipMemcpy(pv_host.data(), + attn_a.pv_threshold_per_head_ptr, + static_cast(nhead_q) * sizeof(float), + hipMemcpyDeviceToHost); + if(err != hipSuccess) + { + std::cerr << "sparge_sparge_fwd_combined: hipMemcpy pv_threshold_per_head failed: " + << hipGetErrorString(err) << std::endl; + return -1.f; + } + for(int h = 0; h < nhead_q; ++h) + { + if(pv_host[h] >= 1e29f) + false_heads.push_back(h); + else + true_heads.push_back(h); + } + } + else + { + // Scalar mode: identity remap, single binary picked by pv_mode_compile + // (R30) or the legacy pv_skip_compile bool (R25 A1). When the scalar + // pv_threshold is the sentinel, force the kNone binary regardless of + // mode_compile — the mode is then irrelevant because no skip happens. 
+ if(attn_a.pv_threshold >= 1e29f) + for(int h = 0; h < nhead_q; ++h) + false_heads.push_back(h); + else + for(int h = 0; h < nhead_q; ++h) + true_heads.push_back(h); + } + + // R26-R3 gate: skip empty buckets so we never schedule a zero-grid launch. + const bool need_false = !false_heads.empty(); + const bool need_true = !true_heads.empty(); + + // Materialise per-bucket head-remap device buffers (one int32 per head in the + // bucket; freed at end of this function, so they stay alive across the launch). + ck_tile::DeviceMem false_remap_dev(std::max(1, false_heads.size() * sizeof(int32_t))); + ck_tile::DeviceMem true_remap_dev(std::max(1, true_heads.size() * sizeof(int32_t))); + if(need_false) + false_remap_dev.ToDevice(false_heads.data()); + if(need_true) + true_remap_dev.ToDevice(true_heads.data()); + + // Build per-bucket attn args. Scalar pv_threshold field is left as-is so the + // device fallback (taken whenever pv_threshold_per_head is null) remains + // correct; the per-head buffer takes priority whenever it is non-null. + fmha_sparge_fwd_args attn_false = attn_a; + fmha_sparge_fwd_args attn_true = attn_a; + // R30: derive the effective per-bucket mode. The "true" (finite-threshold) + // bucket inherits attn_a.pv_mode_compile so the CLI --pv_mode picks per-wave + // (1) or per-block (2). The "false" (sentinel) bucket is always mode 0 + // (kNone). The legacy pv_skip_compile bool is not consulted for dispatch: + // a caller that leaves pv_mode_compile at its default 1 gets the per-wave + // binary for the finite bucket.
+ if(need_false) + { + attn_false.head_remap_ptr = static_cast(false_remap_dev.GetDeviceBuffer()); + attn_false.nhead_in_launch = static_cast(false_heads.size()); + attn_false.pv_skip_compile = false; // legacy bool — kept consistent + attn_false.pv_mode_compile = 0; // route to kNone binary (R30) + } + if(need_true) + { + attn_true.head_remap_ptr = static_cast(true_remap_dev.GetDeviceBuffer()); + attn_true.nhead_in_launch = static_cast(true_heads.size()); + attn_true.pv_skip_compile = true; // legacy bool — kept consistent + // R30: pv_mode_compile carries through unchanged from attn_a (CLI choice). + // attn_true is a copy of attn_a, so attn_true.pv_mode_compile already + // holds the user's selection (0 = kNone, 1 = per-wave, 2 = per-block). + // We deliberately do NOT override mode 0 here: if the user passes + // --pv_mode=none together with a finite pv_threshold, that is an + // explicit "build the bucket but don't skip" request (useful as a + // control measurement). Routing it to kNone keeps the CLI honest. + } + + // Chain callables: kstats -> blockmap -> [fmha_false?] -> [fmha_true?]. + // Empty buckets are skipped by emitting an empty lambda; the wrapped path + // never issues a kernel launch in that branch. + auto cb_kstats = [=](const ck_tile::stream_config& s_) { + sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_bmap = [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_fmha_false = [=](const ck_tile::stream_config& s_) { + if(need_false) + fmha_sparge_fwd_oneshot(attn_t, attn_false, s_); + }; + auto cb_fmha_true = [=](const ck_tile::stream_config& s_) { + if(need_true) + fmha_sparge_fwd_oneshot(attn_t, attn_true, s_); + }; + + // launch_kernel returns elapsed ms for the whole chain when timing is on. 
+ // We always pass 4 callables and gate execution inside the lambda; this + // keeps the timing contract stable, while a no-op lambda has negligible + // (~ns) cost compared to the saved 5-15us host launch. + return ck_tile::launch_kernel(s, cb_kstats, cb_bmap, cb_fmha_false, cb_fmha_true); } diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index ae0952cc41..8ba1b97e84 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -110,11 +111,22 @@ auto create_args(int argc, char* argv[]) .insert("pv_threshold", "1e30", "SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip") + .insert("pv_threshold_per_head", + "", + "R26 split-launch: comma-separated per-head pv_threshold list " + "(length must == h). Empty = scalar mode using -pv_threshold.") .insert("pv_skip_compile", "1", "R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use " "kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to " - "VSA baseline)"); + "VSA baseline). Deprecated by -pv_mode; kept for back-compat scripts.") + .insert("pv_mode", + "warp", + "R30: PV-skip mode select. one of {none, warp, block}. " + "none = no skip (kNone binary; matches VSA baseline). " + "warp = per-wavefront butterfly vote (R25 A1; default). " + "block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). 
" + "Overrides -pv_skip_compile when set explicitly."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -147,6 +159,31 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::string dump_o_path = arg_parser.get_str("dump_o"); float pv_threshold = arg_parser.get_float("pv_threshold"); int pv_skip_compile = arg_parser.get_int("pv_skip_compile"); + std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head"); + std::string pv_mode_str = arg_parser.get_str("pv_mode"); + + // R30: --pv_mode maps to the int dispatched at host. + // none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock). + // Back-compat: if the user explicitly passed -pv_skip_compile=0 but left + // -pv_mode at default ("warp"), honour the legacy intent (mode=0). The CLI + // doesn't expose "was this passed explicitly", so we mirror the rule used + // pre-R30: bool 0 => kNone, bool 1 => kPerWave. + int pv_mode_compile; + if(pv_mode_str == "none") + pv_mode_compile = 0; + else if(pv_mode_str == "warp") + pv_mode_compile = 1; + else if(pv_mode_str == "block") + pv_mode_compile = 2; + else + { + std::cerr << "Unknown -pv_mode value: '" << pv_mode_str + << "' (expected one of: none, warp, block)" << std::endl; + return false; + } + // Legacy bool wins iff user explicitly disabled and pv_mode stayed warp. + if(pv_skip_compile == 0 && pv_mode_str == "warp") + pv_mode_compile = 0; if(nhead_k < 0) nhead_k = nhead; @@ -271,6 +308,31 @@ bool run_test(const ck_tile::ArgParser& arg_parser) static_cast(cdf_per_head_dev.GetDeviceBuffer()); } + // R26 split-launch: optional per-head pv_threshold buffer. Parse the CLI + // comma list (length must match nhead); empty list -> scalar broadcast + // (legacy path, single launch via host). 
+ ck_tile::DeviceMem pv_per_head_dev(static_cast(nhead) * sizeof(float)); + std::vector pv_per_head_host; + bool use_pv_per_head = false; + if(!pv_per_head_s.empty()) + { + std::stringstream ss(pv_per_head_s); + std::string item; + while(std::getline(ss, item, ',')) + { + if(!item.empty()) + pv_per_head_host.push_back(std::stof(item)); + } + if(static_cast(pv_per_head_host.size()) != nhead) + { + std::cerr << "\n[pv_threshold_per_head] length " << pv_per_head_host.size() + << " != h=" << nhead << std::endl; + return false; + } + pv_per_head_dev.ToDevice(pv_per_head_host.data()); + use_pv_per_head = true; + } + // ---- build attention args ---- ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = nullptr; @@ -354,21 +416,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.scale_s = scale_s; attn_args.pv_threshold = pv_threshold; attn_args.pv_skip_compile = (pv_skip_compile != 0); - attn_args.stride_q = q_strides[i_perm ? 2 : 1]; - attn_args.stride_k = k_strides[i_perm ? 2 : 1]; - attn_args.stride_v = v_strides[i_perm ? 2 : 1]; - attn_args.stride_o = o_strides[o_perm ? 2 : 1]; - attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; - attn_args.batch_stride_q = q_strides[0]; - attn_args.batch_stride_k = k_strides[0]; - attn_args.batch_stride_v = v_strides[0]; - attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.pv_mode_compile = pv_mode_compile; // R30: 0=none,1=warp,2=block + // R26 split-launch: when CLI provided per-head list, hand the device + // buffer to the combined wrapper; host code there will partition heads + // into 2 buckets and issue per-bucket launches. + attn_args.pv_threshold_per_head_ptr = + use_pv_per_head ? 
static_cast(pv_per_head_dev.GetDeviceBuffer()) + : nullptr; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; avg_ms = sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp index d600ff7075..cbca128ca6 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp @@ -7,6 +7,9 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +// PVSkipMode enum lives in the sparge pipeline header; pull it in so the +// kernel template arg can name it (R30: promote bool kEnablePVSkip_ to 3-way enum). 
+#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp" #include #include @@ -21,7 +24,9 @@ namespace ck_tile { -template +template struct FmhaFwdSpargeKernel { using FmhaPipeline = ck_tile::remove_cvref_t; @@ -30,7 +35,9 @@ struct FmhaFwdSpargeKernel static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; - static constexpr bool kEnablePVSkip = kEnablePVSkip_; + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias preserved: any non-kNone mode is "PV-skip enabled". + static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; @@ -99,6 +106,15 @@ struct FmhaFwdSpargeKernel ck_tile::index_t nhead_ratio_qk; float scale_s; float pv_threshold; + // R26 split-launch: when non-null, indexed by remapped i_nhead (post head_remap), + // overrides scalar pv_threshold. Buffer length = num_head_q. + const float* pv_threshold_per_head; + // R26 split-launch: when non-null, i_nhead = head_remap_ptr[blockIdx.y]. + // Buffer length = nhead_in_launch. Null = identity (blockIdx.y directly). + const int* head_remap_ptr; + // R26 split-launch: gridDim.y when head_remap_ptr is active (== bucket size). + // Kept for future host-side asserts / debug; kernel reads via blockIdx.y. + ck_tile::index_t nhead_in_launch; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -165,7 +181,12 @@ struct FmhaFwdSpargeKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + // R26 split-launch (default-null preserves + // backward compat = scalar mode). 
+ const float* pv_threshold_per_head = nullptr, + const int* head_remap_ptr = nullptr, + ck_tile::index_t nhead_in_launch = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -185,6 +206,9 @@ struct FmhaFwdSpargeKernel scale_s, #endif pv_threshold, + pv_threshold_per_head, + head_remap_ptr, + nhead_in_launch, stride_q, stride_k, stride_v, @@ -224,7 +248,18 @@ struct FmhaFwdSpargeKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; + // R26 split-launch: if head_remap_ptr is set, translate the launch-local + // head index to the original num_head_q-space index. Null pointer -> + // identity (single-launch backward compat). The remap LUT load is uniform + // across the wavefront (same blockIdx.y for all lanes), but the compiler + // can't infer scalar-uniformity through a global ptr indirection, so we + // broadcast via readfirstlane. Without this, dependent offset/buffer- + // descriptor computations spill to VGPRs and buffer_load_dwordx4 inline + // asm rejects the VGPR operand. + const index_t i_nhead = + (kargs.head_remap_ptr != nullptr) + ? __builtin_amdgcn_readfirstlane(kargs.head_remap_ptr[blockIdx.y]) + : static_cast(blockIdx.y); const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { @@ -402,6 +437,23 @@ struct FmhaFwdSpargeKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + // R26 split-launch: per-head pv_threshold override (null = scalar mode). + // i_nhead is already scalar-broadcast in GetTileIndex; the load is uniform + // and the resulting float lands in SGPRs naturally. We additionally route + // via readfirstlane on the int representation as a defensive hint to keep + // it scalar even when the compiler is conservative about float traffic. 
+ float pv_threshold_resolved; + if(kargs.pv_threshold_per_head != nullptr) + { + const int raw = __builtin_amdgcn_readfirstlane( + __builtin_bit_cast(int, kargs.pv_threshold_per_head[i_nhead])); + pv_threshold_resolved = __builtin_bit_cast(float, raw); + } + else + { + pv_threshold_resolved = kargs.pv_threshold; + } + auto o_acc_tile = FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, @@ -409,7 +461,7 @@ struct FmhaFwdSpargeKernel valid_block_num_value, mask, kargs.scale_s, - kargs.pv_threshold, + pv_threshold_resolved, variant, variant_params, block_indices, diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp index 4bf3c9d296..0a8baa4e62 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp @@ -11,18 +11,40 @@ namespace ck_tile { +// R30: PV-skip mode enum. R25 A1 shipped a per-wavefront vote; R30 adds a +// per-block consensus vote (matches upstream SpargeAttn kPerBlock semantics; +// see R29 researcher report per_block_vload_guard.md). kNone disables the +// skip path entirely (AST removed). The legacy bool `kEnablePVSkip_=true` +// maps to kPerWave; `false` maps to kNone — preserved via codegen. +enum class PVSkipMode : int +{ + kNone = 0, + kPerWave = 1, + kPerBlock = 2, +}; + // Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA; // adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original // _vsa.hpp can remain frozen as an A/B baseline. // +// R30: kPVSkipMode_ promoted from bool to 3-value enum {kNone, kPerWave, kPerBlock}. +// kPerWave is the R25 A1 shipped path; kPerBlock adds a block-wide consensus AND vote +// (1 LDS slot + 1 block_sync_lds) so all waves in a block agree before skipping the +// PV mma. 
Per R29 audit, the V load / V->LDS store / cp_async pipeline stay +// unconditional in BOTH per-wave and per-block modes (only the gemm_1 is gated). +// // QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs; // _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim. template + typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave> struct BlockFmhaPipelineQRKSVSAsyncSparge { - static constexpr bool kEnablePVSkip = kEnablePVSkip_; + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias: true iff any PV-skip mode (per-wave or per-block) is active. + // Kept so existing `if constexpr (kEnablePVSkip)` reads still compile. + static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); + static constexpr bool kPerBlockPVSkip = (kPVSkipMode_ == PVSkipMode::kPerBlock); using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -140,7 +162,22 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge static constexpr const char* name = "qr_async"; + // R30: per-block PV-skip needs one int32 LDS slot to broadcast the AND-vote + // result across waves. Reserved at the TAIL of the pipeline's LDS budget + // (after the existing K + V allocations), 4 bytes, aligned. When mode is + // kNone or kPerWave the byte is unused; the sentinel cost is negligible + // (4 bytes vs the multi-kB K/V tiles) so we always reserve it to keep the + // smem layout uniform across modes — simpler than per-mode policy plumbing. + static constexpr ck_tile::index_t kPerBlockVoteSlotBytes = 4; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize() + kPerBlockVoteSlotBytes; + } + + // R30: byte offset of the per-block vote flag from `smem_ptr`. Lives just + // past the policy's K+V smem footprint. 
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPerBlockVoteSlotOffset() { return Policy::template GetSmemSize(); } @@ -513,6 +550,69 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge }; const bool warp_skip = compute_warp_skip(); + // ================================================================ + // R30: per-block PV-skip — block-wide AND vote over warp_skip. + // Hand-rolled (no `block_and` primitive in CK-tile, no + // `__syncthreads_and` analog — see R30 idiom catalog §7.5). + // + // Protocol: + // 1. Lane 0 of each wave atomicAnd's its warp_skip int into a + // shared LDS sentinel (initialised to 1 by lane 0 of wave 0 + // and published with a block_sync_lds() before the vote). + // 2. block_sync_lds() — all stores visible, all waves rendezvous + // (uses the same s_waitcnt+s_barrier discipline as the K/V + // LDS chain; lgkmcnt accounting stays consistent — idiom + // §3.1 / §4.2). + // 3. All lanes read the sentinel back into a register. The + // result is identical in every lane of the block (no readfirstlane + // is applied here) — used to gate the gemm_1 calls in the k1 loop and the tail below. + // + // Cost: 1 LDS init + 1 atomicAnd per wave + 2 block_sync_lds + 1 LDS load. + // The vote slot lives at `smem_ptr + GetPerBlockVoteSlotOffset()`, + // the 4 bytes just past the policy K+V budget (see GetSmemSize override). + // No interaction with LdsSeq rotation slots. + // + // V load / V->LDS store / cp_async pipeline stay UNCONDITIONAL in + // both per-wave and per-block modes — matches upstream SpargeAttn + // (R29 audit) and CK-tile LDS-rotation discipline. + // ================================================================ + bool block_skip = false; + if constexpr(kPerBlockPVSkip) + { + // Carve a 4-byte uint32 slot at the LDS tail. The cast is safe: + // GetSmemSize() bumped the smem_ptr allocation by 4 bytes (see + // pipeline override above), so the slot is dedicated to this + // pipeline instance and never reused by K/V tiles. 
+ auto* vote_slot = reinterpret_cast(static_cast(smem_ptr) + + GetPerBlockVoteSlotOffset()); + + const int lane_id = threadIdx.x % warpSize; + const int warp_id = threadIdx.x / warpSize; + + // Initialise the sentinel to 1 (the AND identity: skip unless a + // wave votes 0). Only one thread does the init; the subsequent + // block_sync_lds() makes it visible to all waves. + if(warp_id == 0 && lane_id == 0) + { + *vote_slot = 1u; + } + block_sync_lds(); + + // Each wave contributes its warp_skip (already wave-uniform + // after the butterfly in compute_warp_skip). Lane 0 of each + // wave issues the atomicAnd; other lanes are idle. The atomic + // is on LDS (expected to lower to ds_and_b32), much cheaper than global. + if(lane_id == 0) + { + atomicAnd(vote_slot, warp_skip ? 1u : 0u); + } + block_sync_lds(); + + // Broadcast the consensus back to every lane. + const uint32_t consensus = *vote_slot; + block_skip = (consensus != 0u); + } + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { if constexpr(FmhaMask::IsMasking) { @@ -530,6 +630,10 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge // R25 redesign D: when kEnablePVSkip + warp_skip, we zero this // warp's owned rows of p_compute so the unconditional gemm_1 // contributes zero to o_acc, and skip the rowsum. + // R30: per-block mode uses block_skip (uniform across waves) and + // additionally skips gemm_1 itself (see guard at the gemm_1 site + // below). The p_compute zeroing remains so rowsum_p -> 0 and + // `l += rowsum_p` is a no-op for skipped iters. 
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -538,7 +642,15 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(kEnablePVSkip) + if constexpr(kPerBlockPVSkip) + { + if(block_skip) + { + p_compute(i_j_idx) = SMPLComputeDataType{0}; + return; + } + } + else if constexpr(kEnablePVSkip) { if(warp_skip) { @@ -603,15 +715,39 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge number<-1>{}, bool_constant{}); // load next v_buf } + // block_sync_lds() stays UNCONDITIONAL — it is the + // workgroup barrier the V->LDS rotation chain requires + // (idiom catalog §3.1 / §4.1). Only the gemm_1 MFMA is + // gated on block_skip when in per-block mode. block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } if constexpr(std::is_same_v) @@ -659,16 +795,37 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } - // tail — gemm_1 runs unconditionally under redesign D. + // tail — gemm_1 runs unconditionally under redesign D (per-wave). 
+ // R30: per-block mode gates the MFMA on block_skip; block_sync_lds + // still runs unconditionally (workgroup barrier for LDS rotation). { block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } } } while(i_total_loops < num_total_loop);