sparse_attn: split KStats kernel, add README + perf charts

- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
  per-block K stats workspace consumed by Kernel B), removing redundant
  K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
  to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
  + reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-05-05 03:13:24 -04:00
parent eca3cb3e0a
commit b00e5449c8
11 changed files with 839 additions and 96 deletions

View File

@@ -105,7 +105,10 @@ auto create_args(int argc, char* argv[])
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
.insert("kname", "0", "print kernel name")
.insert("perhead", "0",
"R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
"(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -135,6 +138,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
int perhead = arg_parser.get_int("perhead");
if(nhead_k < 0) nhead_k = nhead;
if(seqlen_k < 0) seqlen_k = seqlen_q;
@@ -231,6 +235,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
// R21A Phase 4 + R21B fix: per-head superparam buffers. Each buffer holds
// `nhead` floats (i.e. it is sized [nhead_q]) to match the SpargeAttn upstream
// contract (utils.py:324-328, Headnum=q.size(1)).
// The K-side kernel reads only the first nhead_k entries, indexed by [hk].
ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
if(perhead != 0)
{
std::vector<float> topk_h(nhead);
std::vector<float> sim1_h(nhead);
std::vector<float> cdf_h (nhead);
for(int h = 0; h < nhead; ++h)
{
// Per-head jitter in [-0.25, +0.25), applied multiplicatively to the scalar topk so that sparsity differs by head.
const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
topk_h[h] = topk * (1.0f + jitter);
sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
cdf_h[h] = cdfthreshd;
}
topk_per_head_dev.ToDevice(topk_h.data());
sim1_per_head_dev.ToDevice(sim1_h.data());
cdf_per_head_dev .ToDevice(cdf_h.data());
bmap_args.topk_per_head_ptr = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
bmap_args.cdfthreshd_per_head_ptr = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
}
// ---- build attention args ----
ck_tile::stream_config stream_cfg;
stream_cfg.stream_id_ = nullptr;