From 7103eacc99829af14680e20ab82ab7c755aa6e07 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Sun, 17 May 2026 02:34:23 -0400 Subject: [PATCH] refactor(sparse_attn): caller-owned workspace + dtype-aware sizing Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout() host helpers. Split the combined sparge_blockmap_fwd into stage launchers (sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained launch is timed end-to-end. Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve workspace footprint and match dense-FMHA precision. Tighten per-head superparam pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging. Co-Authored-By: Claude Opus 4 --- .../50_sparse_attn/sparge_blockmap_inst.cpp | 287 ++++++++--------- .../50_sparse_attn/sparge_blockmap_trek.hpp | 64 ++-- .../ck_tile/50_sparse_attn/test_sparge.cpp | 296 +++++++++--------- .../kernel/fmha_fwd_jenga_kernel.hpp | 68 ++-- .../kernel/fmha_fwd_vsa_kernel.hpp | 3 +- .../kernel/sparge_blockmap_kernel.hpp | 38 +-- .../kernel/sparge_kstats_kernel.hpp | 24 +- .../pipeline/sparge_blockmap_pipeline.hpp | 9 +- .../pipeline/sparge_kstats_pipeline.hpp | 12 +- 9 files changed, 402 insertions(+), 399 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index 3cc674f181..0442f1de85 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -67,47 +67,24 @@ using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline; using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel; -// ============================================================================ -// bf16: D=128, kM0=64, kN0=128 -// ============================================================================ +// bf16: dtype-independent aliases share fp16 chain; only problem differs. +using bmap_bf16_block_tile = bmap_fp16_block_tile; +using bmap_bf16_shape = bmap_fp16_shape; +using bmap_bf16_trait = bmap_fp16_trait; +using bmap_bf16_variant = bmap_fp16_variant; +using bmap_bf16_mask = bmap_fp16_mask; -using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; - -using bmap_bf16_shape = - ck_tile::TileFmhaShape, - ck_tile::sequence<16, 16, 16>, - ck_tile::sequence<4, 1, 1>, - ck_tile::sequence<16, 16, 16>, - true>; - -using bmap_bf16_trait = ck_tile::TileFmhaTraits; - -using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; -using bmap_bf16_mask = ck_tile::GenericAttentionMask; - -using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem; using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel; // ============================================================================ -// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized -// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged. +// Workspace layout: caller owns the buffer; we just compute size + offsets. +// Layout = [pooled_k (KDataType) | sim_k (uint8)]. sim_k follows pooled_k with +// no padding (uint8 has alignment 1). // ============================================================================ namespace { -struct KStatsWorkspace +constexpr int sparge_kN0_for(int hdim_q) { - void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32 - void* sim_k_dev = nullptr; // [batch, nhead_k, N_k] uint8 - size_t pooled_k_bytes = 0; - size_t sim_k_bytes = 0; - - void ensure(int batch, int nhead_k, int N_k, int D) - { - const size_t need_p = static_cast(batch) * nhead_k * N_k * D * sizeof(float); - const size_t need_s = static_cast(batch) * nhead_k * N_k * sizeof(uint8_t); - if(need_p > pooled_k_bytes) - { - if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev); - (void)hipMalloc(&pooled_k_dev, need_p); - pooled_k_bytes = need_p; - } - if(need_s > sim_k_bytes) - { - if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev); - (void)hipMalloc(&sim_k_dev, need_s); - sim_k_bytes = need_s; - } - } -}; - -KStatsWorkspace& g_kstats_ws() -{ - static KStatsWorkspace ws; - return ws; + // d=128 instances use kN0=128 (see bmap_fp16_block_tile). + return (hdim_q == 128) ? 128 : 0; } -template -void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s) +size_t dtype_bytes(const std::string& dt) { - const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0); - const int D = BlockMapKernel::D; - auto& ws = g_kstats_ws(); - ws.ensure(args.batch, args.nhead_k, N_k, D); + if(dt == "fp16" || dt == "bf16") + return 2; + return 0; +} - // Stage 1: K stats - { - auto [kargs, grids] = - sparge_kstats_create_kargs_and_grids(args, ws.pooled_k_dev, ws.sim_k_dev); - const dim3 blocks = KStatsKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu; - ck_tile::make_kernel(KStatsKernel{}, grids, blocks, 0, kargs)( - ck_tile::stream_config{s.stream_id_}); - } - // Stage 2: block_map (reads ws) - { - auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids( - args, ws.pooled_k_dev, ws.sim_k_dev); - const dim3 blocks = BlockMapKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu; - ck_tile::make_kernel(BlockMapKernel{}, grids, blocks, 0, kargs)( - ck_tile::stream_config{s.stream_id_}); - } +} // namespace + +sparge_blockmap_workspace_layout +sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args) +{ + const int kN0 = sparge_kN0_for(traits.hdim_q); + const int N_k = (kN0 > 0) ? ck_tile::integer_divide_ceil(args.seqlen_k, kN0) : 0; + const int D = traits.hdim_q; + const size_t element_bytes = dtype_bytes(traits.data_type); + + sparge_blockmap_workspace_layout layout{}; + layout.pooled_k_offset = 0; + layout.pooled_k_bytes = + static_cast(args.batch) * args.nhead_k * N_k * D * element_bytes; + layout.sim_k_offset = layout.pooled_k_bytes; + layout.sim_k_bytes = static_cast(args.batch) * args.nhead_k * N_k * sizeof(uint8_t); + layout.total_bytes = layout.sim_k_offset + layout.sim_k_bytes; + return layout; +} + +// ============================================================================ +// Stage launchers: read args.workspace_ptr split per layout, run one kernel. +// ============================================================================ + +namespace { + +template +void launch_kstats_only(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + const auto layout = sparge_blockmap_compute_workspace_layout(traits, args); + auto* ws_base = static_cast(args.workspace_ptr); + void* pooled_k_ptr = ws_base + layout.pooled_k_offset; + void* sim_k_ptr = ws_base + layout.sim_k_offset; + + auto [kargs, grids] = + sparge_kstats_create_kargs_and_grids(args, pooled_k_ptr, sim_k_ptr); + const dim3 blocks = KStatsKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu; + ck_tile::make_kernel(KStatsKernel{}, grids, blocks, 0, kargs)(s); +} + +template +void launch_blockmap_only(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + const auto layout = sparge_blockmap_compute_workspace_layout(traits, args); + auto* ws_base = static_cast(args.workspace_ptr); + void* pooled_k_ptr = ws_base + layout.pooled_k_offset; + void* sim_k_ptr = ws_base + layout.sim_k_offset; + + auto [kargs, grids] = + sparge_blockmap_create_kargs_and_grids(args, pooled_k_ptr, sim_k_ptr); + const dim3 blocks = BlockMapKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu; + ck_tile::make_kernel(BlockMapKernel{}, grids, blocks, 0, kargs)(s); } } // namespace // ============================================================================ -// Dispatch +// Oneshot stages (no timing): caller chains them via launch_kernel. // ============================================================================ -float sparge_blockmap_fwd(sparge_blockmap_traits traits, - sparge_blockmap_args args, - const ck_tile::stream_config& s) +void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) { if(traits.data_type == "fp16" && traits.hdim_q == 128) { - if(s.log_level_ > 0) - std::cout << ", sparge_blockmap_fp16_d128" << std::flush; - return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) { - launch_kstats_then_blockmap(args, s_); - }); - } - - if(traits.data_type == "bf16" && traits.hdim_q == 128) - { - if(s.log_level_ > 0) - std::cout << ", sparge_blockmap_bf16_d128" << std::flush; - return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) { - launch_kstats_then_blockmap(args, s_); - }); - } - - if(s.log_level_ > 0) - std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type - << ", hdim_q=" << traits.hdim_q << ")" << std::endl; - return -1.f; -} - -// ============================================================================ -// Oneshot version: launches kernel without timing wrapper -// ============================================================================ - -void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, - sparge_blockmap_args args, - const ck_tile::stream_config& s) -{ - if(traits.data_type == "fp16" && traits.hdim_q == 128) - { - launch_kstats_then_blockmap(args, s); + launch_kstats_only(traits, args, s); return; } - if(traits.data_type == "bf16" && traits.hdim_q == 128) { - launch_kstats_then_blockmap(args, s); + launch_kstats_only(traits, args, s); return; } - - std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type + std::cerr << "sparge_kstats_fwd_oneshot: unsupported config (data_type=" << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl; } +void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + launch_blockmap_only(traits, args, s); + return; + } + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + launch_blockmap_only(traits, args, s); + return; + } + std::cerr << "sparge_blockmap_only_fwd_oneshot: unsupported config (data_type=" + << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + // ============================================================================ -// Combined functions: blockmap + attention timed together via launch_kernel +// Combined functions: kstats + blockmap + attention timed together. // ============================================================================ -float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, - fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a, +float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_jenga_fwd_traits attn_t, + fmha_jenga_fwd_args attn_a, const ck_tile::stream_config& s) { if(s.log_level_ > 0) - std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q - << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q - << std::flush; + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; return ck_tile::launch_kernel( s, + [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, [=](const ck_tile::stream_config& s_) { - sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); }, - [=](const ck_tile::stream_config& s_) { - fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); - }); + [=](const ck_tile::stream_config& s_) { fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); }); } -float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, - fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a, +float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_vsa_fwd_traits attn_t, + fmha_vsa_fwd_args attn_a, const ck_tile::stream_config& s) { if(s.log_level_ > 0) - std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q - << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q - << std::flush; + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; return ck_tile::launch_kernel( s, + [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, [=](const ck_tile::stream_config& s_) { - sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); }, - [=](const ck_tile::stream_config& s_) { - fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); - }); + [=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); }); } diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp index 92c32d29e8..4d0e935fc9 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -48,14 +48,23 @@ struct sparge_blockmap_args void* lut_ptr; void* valid_block_num_ptr; - // R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar. - // Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized - // by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the - // first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for - // MQA/GQA aligns to upstream tuned ckpt layout. - const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1]) - const float* cdfthreshd_per_head_ptr = nullptr; // size = nhead_q floats - const float* topk_per_head_ptr = nullptr; // size = nhead_q floats + // Caller-owned K-stats workspace; size from sparge_blockmap_get_workspace_size. + // Internal layout (pooled_k then sim_k) given by sparge_blockmap_workspace_layout. + void* workspace_ptr = nullptr; + + // size = nhead_q to match SpargeAttn upstream hyperparameter_check + const float* simthreshd1_per_head_ptr = nullptr; + const float* cdfthreshd_per_head_ptr = nullptr; + const float* topk_per_head_ptr = nullptr; +}; + +struct sparge_blockmap_workspace_layout +{ + size_t pooled_k_offset; // bytes from workspace_ptr + size_t pooled_k_bytes; + size_t sim_k_offset; // bytes from workspace_ptr + size_t sim_k_bytes; + size_t total_bytes; }; struct sparge_blockmap_traits @@ -127,19 +136,36 @@ auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args, // ============================================================================ // Hand-written template instantiation dispatch // ============================================================================ -float sparge_blockmap_fwd(sparge_blockmap_traits traits, - sparge_blockmap_args args, - const ck_tile::stream_config& stream_config); -void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, - sparge_blockmap_args args, - const ck_tile::stream_config& stream_config); +// Workspace sizing helpers (host, no template instantiation needed). +sparge_blockmap_workspace_layout +sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args); -// Combined functions: blockmap + attention with unified timing -float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args, - fmha_jenga_fwd_traits, fmha_jenga_fwd_args, +inline size_t sparge_blockmap_get_workspace_size(sparge_blockmap_traits traits, + sparge_blockmap_args args) +{ + return sparge_blockmap_compute_workspace_layout(traits, args).total_bytes; +} + +// Stage 1: K-stats only. Writes pooled_k + sim_k into args.workspace_ptr. +void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Stage 2: block_map only. Reads pooled_k + sim_k from args.workspace_ptr. +void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Combined functions: kstats + blockmap + attention with unified timing. +float sparge_jenga_fwd(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_jenga_fwd_traits, + fmha_jenga_fwd_args, const ck_tile::stream_config&); -float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args, - fmha_vsa_fwd_traits, fmha_vsa_fwd_args, +float sparge_vsa_fwd_combined(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_vsa_fwd_traits, + fmha_vsa_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index 4c97a10d0f..a2cf101cf1 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -25,8 +25,11 @@ // ============================================================================ template -ck_tile::HostTensor -make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm) +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) { if(i_perm) return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); @@ -86,8 +89,7 @@ float to_float_for_compare(ck_tile::bf16_t value) auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser - .insert("v", "1", "0:no validation, 1:cpu validation") + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") .insert("pipeline", "jenga", "attention pipeline: jenga / vsa") .insert("b", "1", "batch size") .insert("h", "4", "num of head for q") @@ -105,10 +107,7 @@ auto create_args(int argc, char* argv[]) .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name") - .insert("perhead", "0", - "R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test " - "(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)"); + .insert("kname", "0", "print kernel name"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -120,29 +119,31 @@ auto create_args(int argc, char* argv[]) template bool run_test(const ck_tile::ArgParser& arg_parser) { - int do_validation = arg_parser.get_int("v"); - std::string pipeline = arg_parser.get_str("pipeline"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - float topk = arg_parser.get_float("topk"); - float cdfthreshd = arg_parser.get_float("cdfthreshd"); - float simthreshd1 = arg_parser.get_float("simthreshd1"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int kname = arg_parser.get_int("kname"); - int perhead = arg_parser.get_int("perhead"); + int do_validation = arg_parser.get_int("v"); + std::string pipeline = arg_parser.get_str("pipeline"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + float topk = arg_parser.get_float("topk"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); - if(nhead_k < 0) nhead_k = nhead; - if(seqlen_k < 0) seqlen_k = seqlen_q; - if(hdim_v < 0) hdim_v = hdim_q; + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode if(cdfthreshd >= 0.0f) @@ -162,15 +163,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; std::string prec_str = std::is_same_v ? "fp16" : "bf16"; - std::cout << "[" << pipeline << "|" << prec_str - << "] b=" << batch << " h=" << nhead << " s=" << seqlen_q - << " d=" << hdim_q << " topk=" << topk - << " sim1=" << simthreshd1 << std::flush; + std::cout << "[" << pipeline << "|" << prec_str << "] b=" << batch << " h=" << nhead + << " s=" << seqlen_q << " d=" << hdim_q << " topk=" << topk << " sim1=" << simthreshd1 + << std::flush; // ---- allocate host tensors ---- - auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); - auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); - auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); auto output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); @@ -213,62 +213,61 @@ bool run_test(const ck_tile::ArgParser& arg_parser) bmap_traits.hdim_q = hdim_q; sparge_blockmap_args bmap_args; - bmap_args.q_ptr = q_dev.GetDeviceBuffer(); - bmap_args.k_ptr = k_dev.GetDeviceBuffer(); - bmap_args.batch = batch; - bmap_args.seqlen_q = seqlen_q; - bmap_args.seqlen_k = seqlen_k; - bmap_args.hdim_q = hdim_q; - bmap_args.nhead_q = nhead; - bmap_args.nhead_k = nhead_k; - bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; - bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; - bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; - bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; - bmap_args.batch_stride_q = q_strides[0]; - bmap_args.batch_stride_k = k_strides[0]; - bmap_args.simthreshd1 = simthreshd1; - bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; - bmap_args.topk = topk; - bmap_args.scale = scale_s; - bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); - bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; + bmap_args.q_ptr = q_dev.GetDeviceBuffer(); + bmap_args.k_ptr = k_dev.GetDeviceBuffer(); + bmap_args.batch = batch; + bmap_args.seqlen_q = seqlen_q; + bmap_args.seqlen_k = seqlen_k; + bmap_args.hdim_q = hdim_q; + bmap_args.nhead_q = nhead; + bmap_args.nhead_k = nhead_k; + bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; + bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; + bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + bmap_args.batch_stride_q = q_strides[0]; + bmap_args.batch_stride_k = k_strides[0]; + bmap_args.simthreshd1 = simthreshd1; + bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; + bmap_args.topk = topk; + bmap_args.scale = scale_s; + bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); + bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; - // R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q] - // to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)). + // K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing. + const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args); + ck_tile::DeviceMem kstats_ws_dev(ws_bytes); + bmap_args.workspace_ptr = kstats_ws_dev.GetDeviceBuffer(); + + // Per-head superparam buffers, all sized [nhead_q] to match SpargeAttn upstream contract. // K-side kernel reads only the first nhead_k entries via [hk]. + // Filled with scalar broadcast; per-head index correctness verified by separate unit test. ck_tile::DeviceMem topk_per_head_dev(static_cast(nhead) * sizeof(float)); ck_tile::DeviceMem sim1_per_head_dev(static_cast(nhead) * sizeof(float)); - ck_tile::DeviceMem cdf_per_head_dev (static_cast(nhead) * sizeof(float)); - if(perhead != 0) + ck_tile::DeviceMem cdf_per_head_dev(static_cast(nhead) * sizeof(float)); { - std::vector topk_h(nhead); - std::vector sim1_h(nhead); - std::vector cdf_h (nhead); - for(int h = 0; h < nhead; ++h) - { - // small per-head jitter around scalar topk so sparsity differs by head - const float jitter = 0.5f * (static_cast(h - nhead / 2) / nhead); - topk_h[h] = topk * (1.0f + jitter); - sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1]) - cdf_h[h] = cdfthreshd; - } + std::vector topk_h(nhead, topk); + std::vector sim1_h(nhead, simthreshd1); + std::vector cdf_h(nhead, cdfthreshd); topk_per_head_dev.ToDevice(topk_h.data()); sim1_per_head_dev.ToDevice(sim1_h.data()); - cdf_per_head_dev .ToDevice(cdf_h.data()); - bmap_args.topk_per_head_ptr = static_cast(topk_per_head_dev.GetDeviceBuffer()); - bmap_args.simthreshd1_per_head_ptr = static_cast(sim1_per_head_dev.GetDeviceBuffer()); - bmap_args.cdfthreshd_per_head_ptr = static_cast(cdf_per_head_dev.GetDeviceBuffer()); + cdf_per_head_dev.ToDevice(cdf_h.data()); + bmap_args.topk_per_head_ptr = + static_cast(topk_per_head_dev.GetDeviceBuffer()); + bmap_args.simthreshd1_per_head_ptr = + static_cast(sim1_per_head_dev.GetDeviceBuffer()); + bmap_args.cdfthreshd_per_head_ptr = + static_cast(cdf_per_head_dev.GetDeviceBuffer()); } // ---- build attention args ---- ck_tile::stream_config stream_cfg; - stream_cfg.stream_id_ = nullptr; + stream_cfg.stream_id_ = nullptr; stream_cfg.time_kernel_ = true; - stream_cfg.log_level_ = kname; + stream_cfg.log_level_ = kname; stream_cfg.cold_niters_ = warmup; - stream_cfg.nrepeat_ = repeat; + stream_cfg.nrepeat_ = repeat; float avg_ms = -1.0f; @@ -283,35 +282,35 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.bm0 = BLKQ; fmha_jenga_fwd_args attn_args; - attn_args.q_ptr = q_dev.GetDeviceBuffer(); - attn_args.k_ptr = k_dev.GetDeviceBuffer(); - attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer(); - attn_args.o_ptr = o_dev.GetDeviceBuffer(); - attn_args.seqlen_q = seqlen_q; - attn_args.seqlen_k = seqlen_k; - attn_args.batch = batch; - attn_args.max_seqlen_q = seqlen_q; - attn_args.hdim_q = hdim_q; - attn_args.hdim_v = hdim_v; - attn_args.nhead_q = nhead; - attn_args.nhead_k = nhead_k; - attn_args.scale_s = scale_s; - attn_args.stride_q = q_strides[i_perm ? 2 : 1]; - attn_args.stride_k = k_strides[i_perm ? 2 : 1]; - attn_args.stride_v = v_strides[i_perm ? 2 : 1]; - attn_args.stride_o = o_strides[o_perm ? 2 : 1]; - attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; - attn_args.batch_stride_q = q_strides[0]; - attn_args.batch_stride_k = k_strides[0]; - attn_args.batch_stride_v = v_strides[0]; - attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); } @@ -326,38 +325,39 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.bm0 = BLKQ; fmha_vsa_fwd_args attn_args; - attn_args.q_ptr = q_dev.GetDeviceBuffer(); - attn_args.k_ptr = k_dev.GetDeviceBuffer(); - attn_args.v_ptr = v_dev.GetDeviceBuffer(); - attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer(); - attn_args.o_ptr = o_dev.GetDeviceBuffer(); - attn_args.seqlen_q = seqlen_q; - attn_args.seqlen_k = seqlen_k; - attn_args.batch = batch; - attn_args.max_seqlen_q = seqlen_q; - attn_args.hdim_q = hdim_q; - attn_args.hdim_v = hdim_v; - attn_args.nhead_q = nhead; - attn_args.nhead_k = nhead_k; - attn_args.scale_s = scale_s; - attn_args.stride_q = q_strides[i_perm ? 2 : 1]; - attn_args.stride_k = k_strides[i_perm ? 2 : 1]; - attn_args.stride_v = v_strides[i_perm ? 2 : 1]; - attn_args.stride_o = o_strides[o_perm ? 2 : 1]; - attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; - attn_args.batch_stride_q = q_strides[0]; - attn_args.batch_stride_k = k_strides[0]; - attn_args.batch_stride_v = v_strides[0]; - attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; - avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + avg_ms = + sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); } else { @@ -367,8 +367,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ---- std::size_t flop = static_cast(batch) * nhead * - (static_cast(2) * seqlen_q * seqlen_k * hdim_q + - static_cast(2) * seqlen_q * seqlen_k * hdim_v); + (static_cast(2) * seqlen_q * seqlen_k * hdim_q + + static_cast(2) * seqlen_q * seqlen_k * hdim_v); float tflops = (avg_ms > 0.f) ? static_cast(flop) / 1.E9f / avg_ms : 0.f; if(avg_ms > 0.f) @@ -382,14 +382,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_map_dev.FromDevice(block_map_host.data()); // ---- count active blocks ---- - ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; + ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; ck_tile::index_t active_blocks = 0; for(size_t i = 0; i < block_map_host.mData.size(); ++i) if(block_map_host.mData[i]) active_blocks++; - float actual_sparsity = 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity - << "(" << active_blocks << "/" << total_blocks << ")" << std::flush; + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity << "(" << active_blocks + << "/" << total_blocks << ")" << std::flush; // ---- validation ---- bool pass = true; @@ -405,8 +406,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto [rtol, atol] = get_error_tolerance(); - float max_diff = 0.0f; - size_t num_errors = 0; + float max_diff = 0.0f; + size_t num_errors = 0; auto output_host_bhsd = to_bhsd(output_host, o_perm); for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) @@ -423,9 +424,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) } pass = (num_errors == 0); - std::cout << ", " << (pass ? "PASS" : "FAIL") - << "(err=" << num_errors << "/" << output_host_bhsd.mData.size() - << " maxdiff=" << max_diff << ")"; + std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/" + << output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")"; } std::cout << std::endl; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index cd3513530d..e461f7d743 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include #include #include #include @@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel }; // std::variant<> can't take in a list initializer, overload for backward compatibility - CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + // 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384). + // Not constexpr because the assert needs runtime evaluation. + CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) { + // 256-bool LDS staging caps N_k <= 256 per Q-tile. + // For kN0=64 this means seqlen_k <= 16384. + assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 && + "256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)"); + Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. + // Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile, + // i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs asserts this. + // The extra 1024B is jenga-specific: pipeline (block_fmha_pipeline_qr_ks_vs_async_jenga + // .hpp:261) stages block_relation_onehot here. Do NOT copy this `+ 256*sizeof(int)` to + // other sparse kernels (e.g. VSA) without first wiring a real reader. __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 5caf27756f..14fd86e8d1 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. - __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; + __shared__ char smem_ptr[GetSmemSize()]; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp index 62b5b3591c..9006ee7696 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -22,7 +22,7 @@ struct SpargeBlockMapKernel static constexpr index_t kN0 = Pipeline::kN0; static constexpr index_t D = Pipeline::D; - static constexpr index_t kAlignment = 16 / sizeof(QDataType); + static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width struct Kargs { @@ -52,19 +52,18 @@ struct SpargeBlockMapKernel void* lut_ptr; void* valid_block_num_ptr; - // R20 K-stat workspace from Kernel A - const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32 - const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8 + // K-block stats workspace produced by SpargeKStatsKernel + const void* + pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype) + const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8 index_t N_k; - // R21A Phase 4: optional per-head topk (size = nhead_q floats). - // nullptr => use scalar `topk` for all heads. + // Per-head topk (size = nhead_q floats). Required (non-null). const float* topk_per_head; - // R21B: optional per-head cdfthreshd (size = nhead_q floats). - // nullptr => use scalar `cdfthreshd` for all heads. - // Only consulted on topk<=0 path; bench currently always uses topk path. + // Per-head cdfthreshd (size = nhead_q floats). Required (non-null); + // only consulted on topk<=0 path. const float* cdfthreshd_per_head; }; @@ -90,8 +89,8 @@ struct SpargeBlockMapKernel void* valid_block_num_ptr, const void* pooled_k_ws_ptr, const void* sim_k_ws_ptr, - const float* topk_per_head = nullptr, - const float* cdfthreshd_per_head = nullptr) + const float* topk_per_head, + const float* cdfthreshd_per_head) { const index_t N_k = integer_divide_ceil(seqlen_k, kN0); return Kargs{q_ptr, @@ -195,20 +194,15 @@ struct SpargeBlockMapKernel // Shared memory __shared__ char smem[Pipeline::GetSmemSize()]; - // R20 K-stat workspace: pre-offset for this (b, hk). - const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk; + // K-stat workspace: pre-offset for this (b, hk). + const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk; const index_t khead_off = (b * nhead_k + hk) * N_k; const auto* pooled_k_ws = - reinterpret_cast(kargs.pooled_k_ws_ptr) + khead_off * D; - const auto* sim_k_ws = - reinterpret_cast(kargs.sim_k_ws_ptr) + khead_off; + reinterpret_cast(kargs.pooled_k_ws_ptr) + khead_off * D; + const auto* sim_k_ws = reinterpret_cast(kargs.sim_k_ws_ptr) + khead_off; - // R21A Phase 4: per-head topk if provided, else scalar broadcast. - const float topk_eff = - (kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk; - // R21B: per-head cdfthreshd if provided, else scalar broadcast. - const float cdfthreshd_eff = - (kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd; + const float topk_eff = kargs.topk_per_head[hq]; + const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq]; Pipeline{}(q_window, k_window, diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp index 3ce494f870..893e9a232e 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp @@ -40,15 +40,13 @@ struct SpargeKStatsKernel float simthreshd1; - void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32 + void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype) void* sim_k_ptr; // [batch, nhead_k, N_k] uint8 index_t N_k; - // R21A Phase 4 + R21B fix: optional per-head simthreshd1. - // Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract - // (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first - // nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`. + // Per-head simthreshd1 pointer (size = nhead_q floats; kernel indexes [hk] only). + // Required (non-null); matches SpargeAttn upstream contract. const float* simthreshd1_per_head; }; @@ -62,7 +60,7 @@ struct SpargeKStatsKernel float simthreshd1, void* pooled_k_ptr, void* sim_k_ptr, - const float* simthreshd1_per_head = nullptr) + const float* simthreshd1_per_head) { const index_t N_k = integer_divide_ceil(seqlen_k, kN0); return Kargs{k_ptr, @@ -111,17 +109,15 @@ struct SpargeKStatsKernel {0, 0}, Pipeline::MakeKBlockDistribution()); - const index_t N_k = kargs.N_k; - const index_t khead_off = (b * kargs.nhead_k + hk) * N_k; - auto* pooled_k_out = reinterpret_cast(kargs.pooled_k_ptr) + (khead_off + kb) * D; - auto* sim_k_out = reinterpret_cast(kargs.sim_k_ptr) + (khead_off + kb); + const index_t N_k = kargs.N_k; + const index_t khead_off = (b * kargs.nhead_k + hk) * N_k; + auto* pooled_k_out = + reinterpret_cast(kargs.pooled_k_ptr) + (khead_off + kb) * D; + auto* sim_k_out = reinterpret_cast(kargs.sim_k_ptr) + (khead_off + kb); __shared__ char smem[Pipeline::GetSmemSize()]; - // R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast. - const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr) - ? kargs.simthreshd1_per_head[hk] - : kargs.simthreshd1; + const float simthreshd1_eff = kargs.simthreshd1_per_head[hk]; Pipeline{}(k_window, kargs.seqlen_k, diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp index 25e3b964e9..8d813aa578 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -262,7 +262,7 @@ struct SpargeBlockMapPipeline uint8_t* block_map_ptr, int32_t* lut_ptr, int32_t* valid_block_num_ptr, - const float* __restrict__ pooled_k_ws_ptr, + const KDataType* __restrict__ pooled_k_ws_ptr, const uint8_t* __restrict__ sim_k_ws_ptr, void* smem_ptr) const { @@ -356,10 +356,10 @@ struct SpargeBlockMapPipeline for(index_t kb = 0; kb < N_k; ++kb) { - const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread; + const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread; float pooled_k_mean[KPerThread]; for(index_t k = 0; k < KPerThread; ++k) - pooled_k_mean[k] = p_kb[k]; + pooled_k_mean[k] = type_convert(p_kb[k]); float dot = 0.f; for(index_t k = 0; k < KPerThread; ++k) @@ -417,8 +417,7 @@ struct SpargeBlockMapPipeline // cdfthreshd path (topk <= 0) still requires normalised scores so the // accumulator `cumulative_prob` matches probabilities. const bool topk_active = (topk > 0.f); - const float inv_sum = - (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; if(!topk_active) { for(index_t i = tid; i < N_k; i += kBlockSize) diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp index 1cb96d716a..9c122d8dea 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp @@ -49,8 +49,8 @@ struct SpargeKStatsPipeline index_t seqlen_k, index_t kb, float simthreshd1, - float* __restrict__ pooled_k_out, // D floats - uint8_t* __restrict__ sim_k_out, // 1 byte + KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16) + uint8_t* __restrict__ sim_k_out, // 1 byte void* smem_ptr) const { const index_t tid = static_cast(threadIdx.x); @@ -70,19 +70,19 @@ struct SpargeKStatsPipeline const index_t m_idx = lane_id / KThreads; // pooled_k_mean: column sum then cross-warp reduce. - // R21A: drop trailing sync (next cross_warp_reduce has its own leading sync). + // Drop trailing sync (next cross_warp_reduce has its own leading sync). float pooled_k_mean[KPerThread]; Base::template column_reduce_thread_and_warp(k_data, pooled_k_mean); Base::template column_reduce_cross_warp(pooled_k_mean, smem_reduce); for(index_t k = 0; k < KPerThread; ++k) pooled_k_mean[k] *= inv_bs_k; - // R21A: write pooled_k_mean to global early so its register liveness ends here, + // Write pooled_k_mean to global early so its register liveness ends here, // freeing VGPR before k_sum_hat becomes live. if(warp_id == 0 && m_idx == 0) { for(index_t k = 0; k < KPerThread; ++k) - pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k]; + pooled_k_out[k_idx * KPerThread + k] = type_convert(pooled_k_mean[k]); } // K row L2 norms + normalised column sum (k_sum_hat) @@ -91,7 +91,7 @@ struct SpargeKStatsPipeline float k_sum_hat[KPerThread]; Base::template column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); - // R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write). + // Drop trailing sync (no further smem read; only intra-warp shuffle + global write). Base::template column_reduce_cross_warp(k_sum_hat, smem_reduce); // sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1