mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
sparse_attn: split KStats kernel, add README + perf charts
- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
per-block K stats workspace consumed by Kernel B), removing redundant
K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
+ reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -105,7 +105,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
.insert("kname", "0", "print kernel name")
|
||||
.insert("perhead", "0",
|
||||
"R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
|
||||
"(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -135,6 +138,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int perhead = arg_parser.get_int("perhead");
|
||||
|
||||
if(nhead_k < 0) nhead_k = nhead;
|
||||
if(seqlen_k < 0) seqlen_k = seqlen_q;
|
||||
@@ -231,6 +235,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
// R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
|
||||
// to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
|
||||
// K-side kernel reads only the first nhead_k entries via [hk].
|
||||
ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
|
||||
ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
|
||||
ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
|
||||
if(perhead != 0)
|
||||
{
|
||||
std::vector<float> topk_h(nhead);
|
||||
std::vector<float> sim1_h(nhead);
|
||||
std::vector<float> cdf_h (nhead);
|
||||
for(int h = 0; h < nhead; ++h)
|
||||
{
|
||||
// small per-head jitter around scalar topk so sparsity differs by head
|
||||
const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
|
||||
topk_h[h] = topk * (1.0f + jitter);
|
||||
sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
|
||||
cdf_h[h] = cdfthreshd;
|
||||
}
|
||||
topk_per_head_dev.ToDevice(topk_h.data());
|
||||
sim1_per_head_dev.ToDevice(sim1_h.data());
|
||||
cdf_per_head_dev .ToDevice(cdf_h.data());
|
||||
bmap_args.topk_per_head_ptr = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.cdfthreshd_per_head_ptr = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// ---- build attention args ----
|
||||
ck_tile::stream_config stream_cfg;
|
||||
stream_cfg.stream_id_ = nullptr;
|
||||
|
||||
Reference in New Issue
Block a user