composable_kernel/example/ck_tile/50_sparse_attn/test_sparge.cpp
Gino Lu b00e5449c8 sparse_attn: split KStats kernel, add README + perf charts
- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces a
  per-block K-stats workspace consumed by Kernel B), removing redundant
  K-stat recomputation across Q-blocks; see the sketch below.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
  to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
  + reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
2026-05-05 03:13:24 -04:00
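
A minimal host-side sketch of the two-kernel dataflow described in the commit message, under stated assumptions: `compute_k_stats` and `KStatsWorkspace` are hypothetical stand-ins for the device-side SpargeKStatsKernel and its workspace (the real pipeline lives in sparge_blockmap_trek.hpp), and the K-stat reduction is illustrated here as a per-block mean.

#include <algorithm>
#include <cstddef>
#include <vector>

// "Kernel A" (hypothetical host stand-in): reduce each K-block of blk_k rows to
// a single mean vector. The resulting workspace is computed once and reused by
// "Kernel B" for every Q-block, instead of re-reducing K per Q-block.
struct KStatsWorkspace
{
    std::vector<float> block_mean; // [num_k_blocks * hdim], one mean vector per K-block
};

KStatsWorkspace compute_k_stats(const std::vector<float>& k, // [seqlen_k * hdim], row-major
                                std::size_t seqlen_k,
                                std::size_t hdim,
                                std::size_t blk_k)
{
    const std::size_t num_k_blocks = (seqlen_k + blk_k - 1) / blk_k;
    KStatsWorkspace ws{std::vector<float>(num_k_blocks * hdim, 0.0f)};
    for(std::size_t kb = 0; kb < num_k_blocks; ++kb)
    {
        const std::size_t row_begin = kb * blk_k;
        const std::size_t row_end   = std::min(seqlen_k, row_begin + blk_k);
        for(std::size_t r = row_begin; r < row_end; ++r)
            for(std::size_t d = 0; d < hdim; ++d)
                ws.block_mean[kb * hdim + d] += k[r * hdim + d];
        for(std::size_t d = 0; d < hdim; ++d)
            ws.block_mean[kb * hdim + d] /= static_cast<float>(row_end - row_begin);
    }
    // "Kernel B" would now score each (Q-block, K-block) pair against
    // ws.block_mean to decide which blocks survive into the blockmap/LUT.
    return ws;
}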

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "fmha_fwd_trek.hpp"
#include "sparge_blockmap_trek.hpp"
#include "sparge_tool.hpp"
// ============================================================================
// Helpers
// ============================================================================
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    if(i_perm)
        return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim}); // b*h*s*d
    return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim}); // b*s*h*d
}
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    auto lens = tensor.get_lengths();
    ck_tile::index_t batch  = lens[0];
    ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
    ck_tile::index_t nhead  = is_bhsd ? lens[1] : lens[2];
    ck_tile::index_t hdim   = lens[3];
    ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
    for(ck_tile::index_t b = 0; b < batch; ++b)
        for(ck_tile::index_t h = 0; h < nhead; ++h)
            for(ck_tile::index_t s = 0; s < seqlen; ++s)
                for(ck_tile::index_t d = 0; d < hdim; ++d)
                    out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
    return out;
}
template <typename T>
auto get_error_tolerance()
{
    double rtol = 1e-2;
    double atol = 4e-2;
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
        // bf16 keeps only 8 significand bits (vs 11 for fp16), so widen both tolerances.
        atol = 2e-1;
        rtol = 2e-1;
    }
    return ck_tile::make_tuple(rtol, atol);
}
template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    // Without the custom data type, bf16_t is raw 16-bit storage: reinterpret the
    // bits and convert explicitly instead of relying on static_cast.
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Arg parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser
        .insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("pipeline", "jenga", "attention pipeline: jenga / vsa")
        .insert("b", "1", "batch size")
        .insert("h", "4", "num of head for q")
        .insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
        .insert("s", "4096", "seqlen_q")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)")
        .insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)")
        .insert("simthreshd1", "0.6", "similarity threshold for blockmap")
        .insert("prec", "fp16", "data type: fp16/bf16")
        .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("seed", "42", "random seed")
        .insert("warmup", "5", "warmup iterations")
        .insert("repeat", "20", "benchmark iterations")
        .insert("kname", "0", "print kernel name")
        .insert("perhead", "0",
                "R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
                "(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
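// Example invocations (illustrative; the binary name "test_sparge" is an
// assumption and depends on the build target):
//   ./test_sparge -pipeline=jenga -b=2 -h=32 -s=16384 -d=128 -prec=fp16
//   ./test_sparge -pipeline=vsa -cdfthreshd=0.9 -v=1    (CDF mode, CPU check)
//   ./test_sparge -perhead=1 -topk=0.3                  (per-head superparams)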
// ============================================================================
// Main test
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation         = arg_parser.get_int("v");
    std::string pipeline      = arg_parser.get_str("pipeline");
    ck_tile::index_t batch    = arg_parser.get_int("b");
    ck_tile::index_t nhead    = arg_parser.get_int("h");
    ck_tile::index_t nhead_k  = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q   = arg_parser.get_int("d");
    ck_tile::index_t hdim_v   = arg_parser.get_int("d_v");
    float topk                = arg_parser.get_float("topk");
    float cdfthreshd          = arg_parser.get_float("cdfthreshd");
    float simthreshd1         = arg_parser.get_float("simthreshd1");
    bool i_perm               = arg_parser.get_bool("iperm");
    bool o_perm               = arg_parser.get_bool("operm");
    uint32_t seed             = arg_parser.get_uint32("seed");
    int warmup                = arg_parser.get_int("warmup");
    int repeat                = arg_parser.get_int("repeat");
    int kname                 = arg_parser.get_int("kname");
    int perhead               = arg_parser.get_int("perhead");
    if(nhead_k < 0) nhead_k = nhead;
    if(seqlen_k < 0) seqlen_k = seqlen_q;
    if(hdim_v < 0) hdim_v = hdim_q;
    // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
    if(cdfthreshd >= 0.0f)
        topk = -1.0f;
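    // Exactly one of {topk, cdfthreshd} is active; the inactive one is passed to
    // the blockmap kernel as -1 (mirrored below when filling bmap_args.cdfthreshd).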
    constexpr ck_tile::index_t BLKQ = 64;
    constexpr ck_tile::index_t BLKK = 128;
    if(hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<\n"
                  << "Kernel instances are generated for hdim=128 only.\n";
        return true;
    }
    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
    std::string prec_str = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
    std::cout << "[" << pipeline << "|" << prec_str << "] b=" << batch << " h=" << nhead
              << " s=" << seqlen_q << " d=" << hdim_q << " topk=" << topk
              << " sim1=" << simthreshd1 << std::flush;
    // ---- allocate host tensors ----
    auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    auto output_host = o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
                              : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    ck_tile::HostTensor<uint8_t> block_map_host({batch, nhead, num_q_blocks, num_k_blocks});
    ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
    ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
    // ---- device tensors ----
    ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes());
    q_dev.ToDevice(q_host.data());
    k_dev.ToDevice(k_host.data());
    v_dev.ToDevice(v_host.data());
    o_dev.SetZero();
    block_map_dev.SetZero();
    lut_dev.SetZero();
    valid_bn_dev.SetZero();
    // ---- strides (BHSD when i_perm=true) ----
    auto q_strides = q_host.get_strides();
    auto k_strides = k_host.get_strides();
    auto v_strides = v_host.get_strides();
    auto o_strides = output_host.get_strides();
    float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q)); // softmax scale 1/sqrt(d)
    // ---- build blockmap args ----
    sparge_blockmap_traits bmap_traits;
    bmap_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
    bmap_traits.hdim_q    = hdim_q;
    sparge_blockmap_args bmap_args;
    bmap_args.q_ptr          = q_dev.GetDeviceBuffer();
    bmap_args.k_ptr          = k_dev.GetDeviceBuffer();
    bmap_args.batch          = batch;
    bmap_args.seqlen_q       = seqlen_q;
    bmap_args.seqlen_k       = seqlen_k;
    bmap_args.hdim_q         = hdim_q;
    bmap_args.nhead_q        = nhead;
    bmap_args.nhead_k        = nhead_k;
    bmap_args.stride_q       = q_strides[i_perm ? 2 : 1];
    bmap_args.stride_k       = k_strides[i_perm ? 2 : 1];
    bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
    bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
    bmap_args.batch_stride_q = q_strides[0];
    bmap_args.batch_stride_k = k_strides[0];
    bmap_args.simthreshd1    = simthreshd1;
    bmap_args.cdfthreshd     = (topk < 0.0f) ? cdfthreshd : -1.0f;
    bmap_args.topk           = topk;
    bmap_args.scale          = scale_s;
    bmap_args.block_map_ptr  = block_map_dev.GetDeviceBuffer();
    bmap_args.lut_ptr             = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
    bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
    // R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
    // to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
    // K-side kernel reads only the first nhead_k entries via [hk].
    ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem cdf_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    if(perhead != 0)
    {
        std::vector<float> topk_h(nhead);
        std::vector<float> sim1_h(nhead);
        std::vector<float> cdf_h(nhead);
        for(int h = 0; h < nhead; ++h)
        {
            // small per-head jitter around scalar topk so sparsity differs by head
            const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
            topk_h[h] = topk * (1.0f + jitter);
            sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
            cdf_h[h]  = cdfthreshd;
        }
        topk_per_head_dev.ToDevice(topk_h.data());
        sim1_per_head_dev.ToDevice(sim1_h.data());
        cdf_per_head_dev.ToDevice(cdf_h.data());
        bmap_args.topk_per_head_ptr =
            static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
        bmap_args.simthreshd1_per_head_ptr =
            static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
        bmap_args.cdfthreshd_per_head_ptr =
            static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
    }
    // ---- build attention args ----
    ck_tile::stream_config stream_cfg;
    stream_cfg.stream_id_   = nullptr;
    stream_cfg.time_kernel_ = true;
    stream_cfg.log_level_   = kname;
    stream_cfg.cold_niters_ = warmup;
    stream_cfg.nrepeat_     = repeat;
    float avg_ms = -1.0f;
    if(pipeline == "jenga")
    {
        fmha_jenga_fwd_traits attn_traits;
        attn_traits.hdim_q        = hdim_q;
        attn_traits.hdim_v        = hdim_v;
        attn_traits.data_type     = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
        attn_traits.is_v_rowmajor = true;
        attn_traits.mask_type     = mask_enum::no_mask;
        attn_traits.bm0           = BLKQ;
        fmha_jenga_fwd_args attn_args;
        attn_args.q_ptr                     = q_dev.GetDeviceBuffer();
        attn_args.k_ptr                     = k_dev.GetDeviceBuffer();
        attn_args.v_ptr                     = v_dev.GetDeviceBuffer();
        attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer();
        attn_args.o_ptr                     = o_dev.GetDeviceBuffer();
        attn_args.seqlen_q       = seqlen_q;
        attn_args.seqlen_k       = seqlen_k;
        attn_args.batch          = batch;
        attn_args.max_seqlen_q   = seqlen_q;
        attn_args.hdim_q         = hdim_q;
        attn_args.hdim_v         = hdim_v;
        attn_args.nhead_q        = nhead;
        attn_args.nhead_k        = nhead_k;
        attn_args.scale_s        = scale_s;
        attn_args.stride_q       = q_strides[i_perm ? 2 : 1];
        attn_args.stride_k       = k_strides[i_perm ? 2 : 1];
        attn_args.stride_v       = v_strides[i_perm ? 2 : 1];
        attn_args.stride_o       = o_strides[o_perm ? 2 : 1];
        attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
        attn_args.batch_stride_q = q_strides[0];
        attn_args.batch_stride_k = k_strides[0];
        attn_args.batch_stride_v = v_strides[0];
        attn_args.batch_stride_o = o_strides[0];
        attn_args.window_size_left  = -1;
        attn_args.window_size_right = -1;
        attn_args.mask_type         = 0;
        avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
    }
    else if(pipeline == "vsa")
    {
        fmha_vsa_fwd_traits attn_traits;
        attn_traits.hdim_q        = hdim_q;
        attn_traits.hdim_v        = hdim_v;
        attn_traits.data_type     = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
        attn_traits.is_v_rowmajor = true;
        attn_traits.mask_type     = mask_enum::no_mask;
        attn_traits.bm0           = BLKQ;
        fmha_vsa_fwd_args attn_args;
        attn_args.q_ptr               = q_dev.GetDeviceBuffer();
        attn_args.k_ptr               = k_dev.GetDeviceBuffer();
        attn_args.v_ptr               = v_dev.GetDeviceBuffer();
        attn_args.lut_ptr             = lut_dev.GetDeviceBuffer();
        attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer();
        attn_args.o_ptr               = o_dev.GetDeviceBuffer();
        attn_args.seqlen_q       = seqlen_q;
        attn_args.seqlen_k       = seqlen_k;
        attn_args.batch          = batch;
        attn_args.max_seqlen_q   = seqlen_q;
        attn_args.hdim_q         = hdim_q;
        attn_args.hdim_v         = hdim_v;
        attn_args.nhead_q        = nhead;
        attn_args.nhead_k        = nhead_k;
        attn_args.scale_s        = scale_s;
        attn_args.stride_q       = q_strides[i_perm ? 2 : 1];
        attn_args.stride_k       = k_strides[i_perm ? 2 : 1];
        attn_args.stride_v       = v_strides[i_perm ? 2 : 1];
        attn_args.stride_o       = o_strides[o_perm ? 2 : 1];
        attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
        attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
        attn_args.batch_stride_q = q_strides[0];
        attn_args.batch_stride_k = k_strides[0];
        attn_args.batch_stride_v = v_strides[0];
        attn_args.batch_stride_o = o_strides[0];
        attn_args.window_size_left  = -1;
        attn_args.window_size_right = -1;
        attn_args.mask_type         = 0;
        avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
    }
    else
    {
        std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n";
        return false;
    }
    // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ----
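    // Per batch*head: 2*seqlen_q*seqlen_k*hdim_q flops for S = Q*K^T plus
    // 2*seqlen_q*seqlen_k*hdim_v for O = P*V. avg_ms is in milliseconds, so
    // dividing flops by 1e9 and then by avg_ms yields TFLOP/s (GFLOP/ms == TFLOP/s).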
    std::size_t flop = static_cast<std::size_t>(batch) * nhead *
                       (static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
                        static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
    float tflops = (avg_ms > 0.f) ? static_cast<float>(flop) / 1.E9f / avg_ms : 0.f;
    if(avg_ms > 0.f)
    {
        std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, "
                  << std::setprecision(2) << tflops << " TFlops" << std::flush;
    }
    // ---- copy results back ----
    o_dev.FromDevice(output_host.data());
    block_map_dev.FromDevice(block_map_host.data());
    // ---- count active blocks ----
    ck_tile::index_t total_blocks  = batch * nhead * num_q_blocks * num_k_blocks;
    ck_tile::index_t active_blocks = 0;
    for(size_t i = 0; i < block_map_host.mData.size(); ++i)
        if(block_map_host.mData[i])
            active_blocks++;
    float actual_sparsity =
        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity << "(" << active_blocks
              << "/" << total_blocks << ")" << std::flush;
    // ---- validation ----
    bool pass = true;
    if(do_validation)
    {
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);
        ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s);
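        // The reference consumes the same device-generated blockmap, so validation
        // checks the sparse kernels against blocked attention under that map rather
        // than against dense attention (pruned blocks would otherwise always mismatch).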
        auto [rtol, atol]     = get_error_tolerance<T>();
        float max_diff        = 0.0f;
        size_t num_errors     = 0;
        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val  = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val  = to_float_for_compare(output_ref.mData[i]);
            float diff     = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
            max_diff       = std::max(max_diff, diff);
            if(diff > atol && rel_diff > rtol)
                num_errors++;
        }
        pass = (num_errors == 0);
        std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/"
                  << output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")";
    }
    std::cout << std::endl;
    return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        std::cerr << "Failed to parse arguments\n";
        return -1;
    }
    std::string prec = arg_parser.get_str("prec");
    bool test_result = false;
    if(prec == "fp16")
    {
        test_result = run_test<ck_tile::half_t>(arg_parser);
    }
    else if(prec == "bf16")
    {
        test_result = run_test<ck_tile::bf16_t>(arg_parser);
    }
    else
    {
        std::cerr << "Unsupported precision: " << prec << "\n";
        return -1;
    }
    return test_result ? 0 : -1;
}