refactor(sparse_attn): caller-owned workspace + dtype-aware sizing

Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned
buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout()
host helpers. Split the combined sparge_blockmap_fwd into stage launchers
(sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained
launch is timed end-to-end.

Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve
workspace footprint and match dense-FMHA precision. Tighten per-head superparam
pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to
document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-05-17 02:34:23 -04:00
parent 668e107282
commit 7103eacc99
9 changed files with 402 additions and 399 deletions

View File

@@ -67,47 +67,24 @@ using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_fp16_problem>;
using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel<kstats_fp16_pipeline>;
// ============================================================================
// bf16: D=128, kM0=64, kN0=128
// ============================================================================
// bf16: dtype-independent aliases share fp16 chain; only problem differs.
using bmap_bf16_block_tile = bmap_fp16_block_tile;
using bmap_bf16_shape = bmap_fp16_shape;
using bmap_bf16_trait = bmap_fp16_trait;
using bmap_bf16_variant = bmap_fp16_variant;
using bmap_bf16_mask = bmap_fp16_mask;
using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
using bmap_bf16_shape =
ck_tile::TileFmhaShape<bmap_bf16_block_tile,
ck_tile::sequence<4, 1, 1>,
ck_tile::sequence<16, 16, 16>,
ck_tile::sequence<4, 1, 1>,
ck_tile::sequence<16, 16, 16>,
true>;
using bmap_bf16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
true, // kPadSeqLenK
true, // kPadHeadDimQ
true, // kPadHeadDimV
false, // kHasLogitsSoftCap
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false, // kStoreLSE
false, // kHasDropout
false, // kHasRandVal
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
-1,
false>;
using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
using bmap_bf16_mask = ck_tile::GenericAttentionMask<false>;
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
ck_tile::bf16_t, // KDataType
ck_tile::bf16_t, // VDataType
float, // SaccDataType
float, // SMPLComputeDataType
ck_tile::bf16_t, // BiasDataType
uint8_t, // RandValOutputDataType
float, // LSEDataType
ck_tile::bf16_t, // PDataType
float, // OaccDataType
ck_tile::bf16_t, // ODataType
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
ck_tile::bf16_t, // KDataType
ck_tile::bf16_t, // VDataType
float, // SaccDataType
float, // SMPLComputeDataType
ck_tile::bf16_t, // BiasDataType
uint8_t, // RandValOutputDataType
float, // LSEDataType
ck_tile::bf16_t, // PDataType
float, // OaccDataType
ck_tile::bf16_t, // ODataType
bmap_bf16_shape,
false, // kIsGroupMode
bmap_bf16_variant,
@@ -122,168 +99,168 @@ using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_bf16_problem>;
using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel<kstats_bf16_pipeline>;
// ============================================================================
// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized
// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged.
// Workspace layout: caller owns the buffer; we just compute size + offsets.
// Layout = [pooled_k (KDataType) | sim_k (uint8)]. sim_k follows pooled_k with
// no padding (uint8 has alignment 1).
// ============================================================================
namespace {
struct KStatsWorkspace
constexpr int sparge_kN0_for(int hdim_q)
{
void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32
void* sim_k_dev = nullptr; // [batch, nhead_k, N_k] uint8
size_t pooled_k_bytes = 0;
size_t sim_k_bytes = 0;
void ensure(int batch, int nhead_k, int N_k, int D)
{
const size_t need_p = static_cast<size_t>(batch) * nhead_k * N_k * D * sizeof(float);
const size_t need_s = static_cast<size_t>(batch) * nhead_k * N_k * sizeof(uint8_t);
if(need_p > pooled_k_bytes)
{
if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev);
(void)hipMalloc(&pooled_k_dev, need_p);
pooled_k_bytes = need_p;
}
if(need_s > sim_k_bytes)
{
if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev);
(void)hipMalloc(&sim_k_dev, need_s);
sim_k_bytes = need_s;
}
}
};
KStatsWorkspace& g_kstats_ws()
{
static KStatsWorkspace ws;
return ws;
// d=128 instances use kN0=128 (see bmap_fp16_block_tile).
return (hdim_q == 128) ? 128 : 0;
}
template <typename KStatsKernel, typename BlockMapKernel>
void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s)
size_t dtype_bytes(const std::string& dt)
{
const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0);
const int D = BlockMapKernel::D;
auto& ws = g_kstats_ws();
ws.ensure(args.batch, args.nhead_k, N_k, D);
if(dt == "fp16" || dt == "bf16")
return 2;
return 0;
}
// Stage 1: K stats
{
auto [kargs, grids] =
sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, ws.pooled_k_dev, ws.sim_k_dev);
const dim3 blocks = KStatsKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
}
// Stage 2: block_map (reads ws)
{
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(
args, ws.pooled_k_dev, ws.sim_k_dev);
const dim3 blocks = BlockMapKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
}
} // namespace
// Computes the byte layout of the caller-owned K-stats workspace.
//
// The workspace is split into two back-to-back regions:
//   [pooled_k | sim_k]
// pooled_k starts at offset 0 and holds batch * nhead_k * N_k * hdim_q elements
// of dtype_bytes(traits.data_type) bytes each; sim_k starts immediately after
// it and holds batch * nhead_k * N_k uint8 entries. N_k is seqlen_k divided
// (rounding up) by the K tile size returned by sparge_kN0_for(hdim_q).
//
// When sparge_kN0_for() reports 0 for the head dimension, N_k is 0 and every
// size in the returned layout is 0. When dtype_bytes() reports 0 for the data
// type, only pooled_k_bytes collapses to 0.
//
// traits: supplies hdim_q and data_type.
// args:   supplies batch, nhead_k and seqlen_k.
// Returns the offsets/sizes (in bytes) of both regions plus their total.
sparge_blockmap_workspace_layout
sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args)
{
    const int tile_n = sparge_kN0_for(traits.hdim_q);

    // Number of K blocks along seqlen_k; stays 0 when no tile size is known.
    int num_k_blocks = 0;
    if(tile_n > 0)
    {
        num_k_blocks = ck_tile::integer_divide_ceil(args.seqlen_k, tile_n);
    }

    const int head_dim     = traits.hdim_q;
    const size_t elem_size = dtype_bytes(traits.data_type);

    // One sim_k entry (and one pooled_k row) per (batch, K head, K block).
    // Widen to size_t before multiplying so the product is not formed in int.
    const size_t row_count = static_cast<size_t>(args.batch) * args.nhead_k * num_k_blocks;

    sparge_blockmap_workspace_layout result{};
    result.pooled_k_offset = 0;
    result.pooled_k_bytes  = row_count * head_dim * elem_size;
    // sim_k is placed directly behind pooled_k with no padding.
    result.sim_k_offset = result.pooled_k_bytes;
    result.sim_k_bytes  = row_count * sizeof(uint8_t);
    result.total_bytes  = result.sim_k_offset + result.sim_k_bytes;
    return result;
}
// ============================================================================
// Stage launchers: read args.workspace_ptr split per layout, run one kernel.
// ============================================================================
namespace {
// Launches the K-stats kernel (stage 1) on the given stream configuration.
//
// The pooled_k / sim_k destinations are not allocated here: they are carved out
// of args.workspace_ptr using the offsets reported by
// sparge_blockmap_compute_workspace_layout(traits, args).
// NOTE(review): args.workspace_ptr is not checked for nullptr before the
// offsets are applied — confirm callers always supply a buffer.
//
// KStatsKernel: kernel type providing BlockSize() and kBlockPerCu.
template <typename KStatsKernel>
void launch_kstats_only(sparge_blockmap_traits traits,
                        sparge_blockmap_args args,
                        const ck_tile::stream_config& s)
{
    // Resolve the two sub-buffers inside the caller-owned workspace.
    const sparge_blockmap_workspace_layout ws_layout =
        sparge_blockmap_compute_workspace_layout(traits, args);
    char* workspace_bytes = static_cast<char*>(args.workspace_ptr);
    void* pooled_k_dst    = workspace_bytes + ws_layout.pooled_k_offset;
    void* sim_k_dst       = workspace_bytes + ws_layout.sim_k_offset;

    auto [kargs, grids] =
        sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, pooled_k_dst, sim_k_dst);

    constexpr ck_tile::index_t blocks_per_cu = KStatsKernel::kBlockPerCu;
    const dim3 block_dim                     = KStatsKernel::BlockSize();

    // The stream config is forwarded unchanged to the kernel launch.
    ck_tile::make_kernel<blocks_per_cu>(KStatsKernel{}, grids, block_dim, 0, kargs)(s);
}
// Launches the block-map kernel (stage 2) on the given stream configuration.
//
// The pooled_k / sim_k pointers handed to the kernel arguments are derived from
// args.workspace_ptr plus the offsets reported by
// sparge_blockmap_compute_workspace_layout(traits, args), i.e. the same split
// used by the K-stats stage.
// NOTE(review): args.workspace_ptr is not checked for nullptr before the
// offsets are applied — confirm callers always supply a buffer.
//
// BlockMapKernel: kernel type providing BlockSize() and kBlockPerCu.
template <typename BlockMapKernel>
void launch_blockmap_only(sparge_blockmap_traits traits,
                          sparge_blockmap_args args,
                          const ck_tile::stream_config& s)
{
    // Resolve the two sub-buffers inside the caller-owned workspace.
    const sparge_blockmap_workspace_layout ws_layout =
        sparge_blockmap_compute_workspace_layout(traits, args);
    char* workspace_bytes = static_cast<char*>(args.workspace_ptr);
    void* pooled_k_src    = workspace_bytes + ws_layout.pooled_k_offset;
    void* sim_k_src       = workspace_bytes + ws_layout.sim_k_offset;

    auto [kargs, grids] =
        sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(args, pooled_k_src, sim_k_src);

    constexpr ck_tile::index_t blocks_per_cu = BlockMapKernel::kBlockPerCu;
    const dim3 block_dim                     = BlockMapKernel::BlockSize();

    // The stream config is forwarded unchanged to the kernel launch.
    ck_tile::make_kernel<blocks_per_cu>(BlockMapKernel{}, grids, block_dim, 0, kargs)(s);
}
} // namespace
// ============================================================================
// Dispatch
// Oneshot stages (no timing): caller chains them via launch_kernel.
// ============================================================================
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& s)
void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& s)
{
if(traits.data_type == "fp16" && traits.hdim_q == 128)
{
if(s.log_level_ > 0)
std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s_);
});
}
if(traits.data_type == "bf16" && traits.hdim_q == 128)
{
if(s.log_level_ > 0)
std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s_);
});
}
if(s.log_level_ > 0)
std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
return -1.f;
}
// ============================================================================
// Oneshot version: launches kernel without timing wrapper
// ============================================================================
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& s)
{
if(traits.data_type == "fp16" && traits.hdim_q == 128)
{
launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s);
launch_kstats_only<kstats_fp16_kernel>(traits, args, s);
return;
}
if(traits.data_type == "bf16" && traits.hdim_q == 128)
{
launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s);
launch_kstats_only<kstats_bf16_kernel>(traits, args, s);
return;
}
std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type
std::cerr << "sparge_kstats_fwd_oneshot: unsupported config (data_type=" << traits.data_type
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
}
// Stage-2 dispatcher: selects the block-map kernel instance matching
// (traits.data_type, traits.hdim_q) and launches it once via
// launch_blockmap_only. Supported combinations are fp16/d128 and bf16/d128.
// For any other combination nothing is launched; a diagnostic is written to
// std::cerr and the function returns normally.
void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits,
                                      sparge_blockmap_args args,
                                      const ck_tile::stream_config& s)
{
    // Only head dimension 128 has kernel instances; pick the dtype within it.
    if(traits.hdim_q == 128)
    {
        if(traits.data_type == "fp16")
        {
            launch_blockmap_only<bmap_fp16_kernel>(traits, args, s);
            return;
        }
        if(traits.data_type == "bf16")
        {
            launch_blockmap_only<bmap_bf16_kernel>(traits, args, s);
            return;
        }
    }

    // No matching instance: report the rejected configuration.
    std::cerr << "sparge_blockmap_only_fwd_oneshot: unsupported config (data_type="
              << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl;
}
// ============================================================================
// Combined functions: blockmap + attention timed together via launch_kernel
// Combined functions: kstats + blockmap + attention timed together.
// ============================================================================
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a,
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t,
sparge_blockmap_args bmap_a,
fmha_jenga_fwd_traits attn_t,
fmha_jenga_fwd_args attn_a,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
<< std::flush;
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
[=](const ck_tile::stream_config& s_) {
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
},
[=](const ck_tile::stream_config& s_) {
fmha_jenga_fwd_oneshot(attn_t, attn_a, s_);
});
[=](const ck_tile::stream_config& s_) { fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); });
}
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a,
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t,
sparge_blockmap_args bmap_a,
fmha_vsa_fwd_traits attn_t,
fmha_vsa_fwd_args attn_a,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
<< std::flush;
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
<< ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
[=](const ck_tile::stream_config& s_) {
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
},
[=](const ck_tile::stream_config& s_) {
fmha_vsa_fwd_oneshot(attn_t, attn_a, s_);
});
[=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); });
}

View File

@@ -48,14 +48,23 @@ struct sparge_blockmap_args
void* lut_ptr;
void* valid_block_num_ptr;
// R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar.
// Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized
// by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the
// first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for
// MQA/GQA aligns to upstream tuned ckpt layout.
const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1])
const float* cdfthreshd_per_head_ptr = nullptr; // size = nhead_q floats
const float* topk_per_head_ptr = nullptr; // size = nhead_q floats
// Caller-owned K-stats workspace; size from sparge_blockmap_get_workspace_size.
// Internal layout (pooled_k then sim_k) given by sparge_blockmap_workspace_layout.
void* workspace_ptr = nullptr;
// size = nhead_q to match SpargeAttn upstream hyperparameter_check
const float* simthreshd1_per_head_ptr = nullptr;
const float* cdfthreshd_per_head_ptr = nullptr;
const float* topk_per_head_ptr = nullptr;
};
// Byte layout of the caller-owned K-stats workspace.
// All offsets are measured in bytes from sparge_blockmap_args::workspace_ptr.
// Every field defaults to 0 so that a default-constructed layout describes an
// empty workspace instead of holding indeterminate values.
struct sparge_blockmap_workspace_layout
{
    size_t pooled_k_offset = 0; // bytes from workspace_ptr to the pooled_k region
    size_t pooled_k_bytes  = 0; // size of the pooled_k region in bytes
    size_t sim_k_offset    = 0; // bytes from workspace_ptr to the sim_k region
    size_t sim_k_bytes     = 0; // size of the sim_k region in bytes
    size_t total_bytes     = 0; // total workspace size the caller must provide
};
struct sparge_blockmap_traits
@@ -127,19 +136,36 @@ auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args,
// ============================================================================
// Hand-written template instantiation dispatch
// ============================================================================
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& stream_config);
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& stream_config);
// Workspace sizing helpers (host, no template instantiation needed).
sparge_blockmap_workspace_layout
sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args);
// Combined functions: blockmap + attention with unified timing
float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args,
fmha_jenga_fwd_traits, fmha_jenga_fwd_args,
// Returns the number of bytes the caller must allocate for the K-stats
// workspace, i.e. the total_bytes field of the layout computed by
// sparge_blockmap_compute_workspace_layout for the same traits/args.
inline size_t sparge_blockmap_get_workspace_size(sparge_blockmap_traits traits,
                                                 sparge_blockmap_args args)
{
    const sparge_blockmap_workspace_layout layout =
        sparge_blockmap_compute_workspace_layout(traits, args);
    return layout.total_bytes;
}
// Stage 1: K-stats only. Writes pooled_k + sim_k into args.workspace_ptr.
void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& stream_config);
// Stage 2: block_map only. Reads pooled_k + sim_k from args.workspace_ptr.
void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& stream_config);
// Combined functions: kstats + blockmap + attention with unified timing.
float sparge_jenga_fwd(sparge_blockmap_traits,
sparge_blockmap_args,
fmha_jenga_fwd_traits,
fmha_jenga_fwd_args,
const ck_tile::stream_config&);
float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args,
fmha_vsa_fwd_traits, fmha_vsa_fwd_args,
float sparge_vsa_fwd_combined(sparge_blockmap_traits,
sparge_blockmap_args,
fmha_vsa_fwd_traits,
fmha_vsa_fwd_args,
const ck_tile::stream_config&);

View File

@@ -25,8 +25,11 @@
// ============================================================================
template <typename T>
ck_tile::HostTensor<T>
make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm)
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t seqlen,
ck_tile::index_t hdim,
bool i_perm)
{
if(i_perm)
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
@@ -86,8 +89,7 @@ float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("v", "1", "0:no validation, 1:cpu validation")
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("pipeline", "jenga", "attention pipeline: jenga / vsa")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
@@ -105,10 +107,7 @@ auto create_args(int argc, char* argv[])
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name")
.insert("perhead", "0",
"R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
"(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -120,29 +119,31 @@ auto create_args(int argc, char* argv[])
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
int do_validation = arg_parser.get_int("v");
std::string pipeline = arg_parser.get_str("pipeline");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
float topk = arg_parser.get_float("topk");
float cdfthreshd = arg_parser.get_float("cdfthreshd");
float simthreshd1 = arg_parser.get_float("simthreshd1");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
int perhead = arg_parser.get_int("perhead");
int do_validation = arg_parser.get_int("v");
std::string pipeline = arg_parser.get_str("pipeline");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
float topk = arg_parser.get_float("topk");
float cdfthreshd = arg_parser.get_float("cdfthreshd");
float simthreshd1 = arg_parser.get_float("simthreshd1");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
if(nhead_k < 0) nhead_k = nhead;
if(seqlen_k < 0) seqlen_k = seqlen_q;
if(hdim_v < 0) hdim_v = hdim_q;
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
// If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
if(cdfthreshd >= 0.0f)
@@ -162,15 +163,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::string prec_str = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
std::cout << "[" << pipeline << "|" << prec_str
<< "] b=" << batch << " h=" << nhead << " s=" << seqlen_q
<< " d=" << hdim_q << " topk=" << topk
<< " sim1=" << simthreshd1 << std::flush;
std::cout << "[" << pipeline << "|" << prec_str << "] b=" << batch << " h=" << nhead
<< " s=" << seqlen_q << " d=" << hdim_q << " topk=" << topk << " sim1=" << simthreshd1
<< std::flush;
// ---- allocate host tensors ----
auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
auto output_host = o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
@@ -213,62 +213,61 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
bmap_traits.hdim_q = hdim_q;
sparge_blockmap_args bmap_args;
bmap_args.q_ptr = q_dev.GetDeviceBuffer();
bmap_args.k_ptr = k_dev.GetDeviceBuffer();
bmap_args.batch = batch;
bmap_args.seqlen_q = seqlen_q;
bmap_args.seqlen_k = seqlen_k;
bmap_args.hdim_q = hdim_q;
bmap_args.nhead_q = nhead;
bmap_args.nhead_k = nhead_k;
bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
bmap_args.batch_stride_q = q_strides[0];
bmap_args.batch_stride_k = k_strides[0];
bmap_args.simthreshd1 = simthreshd1;
bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
bmap_args.topk = topk;
bmap_args.scale = scale_s;
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
bmap_args.q_ptr = q_dev.GetDeviceBuffer();
bmap_args.k_ptr = k_dev.GetDeviceBuffer();
bmap_args.batch = batch;
bmap_args.seqlen_q = seqlen_q;
bmap_args.seqlen_k = seqlen_k;
bmap_args.hdim_q = hdim_q;
bmap_args.nhead_q = nhead;
bmap_args.nhead_k = nhead_k;
bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
bmap_args.batch_stride_q = q_strides[0];
bmap_args.batch_stride_k = k_strides[0];
bmap_args.simthreshd1 = simthreshd1;
bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
bmap_args.topk = topk;
bmap_args.scale = scale_s;
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
// R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
// to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
// K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing.
const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args);
ck_tile::DeviceMem kstats_ws_dev(ws_bytes);
bmap_args.workspace_ptr = kstats_ws_dev.GetDeviceBuffer();
// Per-head superparam buffers, all sized [nhead_q] to match SpargeAttn upstream contract.
// K-side kernel reads only the first nhead_k entries via [hk].
// Filled with scalar broadcast; per-head index correctness verified by separate unit test.
ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
if(perhead != 0)
ck_tile::DeviceMem cdf_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
{
std::vector<float> topk_h(nhead);
std::vector<float> sim1_h(nhead);
std::vector<float> cdf_h (nhead);
for(int h = 0; h < nhead; ++h)
{
// small per-head jitter around scalar topk so sparsity differs by head
const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
topk_h[h] = topk * (1.0f + jitter);
sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
cdf_h[h] = cdfthreshd;
}
std::vector<float> topk_h(nhead, topk);
std::vector<float> sim1_h(nhead, simthreshd1);
std::vector<float> cdf_h(nhead, cdfthreshd);
topk_per_head_dev.ToDevice(topk_h.data());
sim1_per_head_dev.ToDevice(sim1_h.data());
cdf_per_head_dev .ToDevice(cdf_h.data());
bmap_args.topk_per_head_ptr = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
bmap_args.cdfthreshd_per_head_ptr = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
cdf_per_head_dev.ToDevice(cdf_h.data());
bmap_args.topk_per_head_ptr =
static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
bmap_args.simthreshd1_per_head_ptr =
static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
bmap_args.cdfthreshd_per_head_ptr =
static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
}
// ---- build attention args ----
ck_tile::stream_config stream_cfg;
stream_cfg.stream_id_ = nullptr;
stream_cfg.stream_id_ = nullptr;
stream_cfg.time_kernel_ = true;
stream_cfg.log_level_ = kname;
stream_cfg.log_level_ = kname;
stream_cfg.cold_niters_ = warmup;
stream_cfg.nrepeat_ = repeat;
stream_cfg.nrepeat_ = repeat;
float avg_ms = -1.0f;
@@ -283,35 +282,35 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.bm0 = BLKQ;
fmha_jenga_fwd_args attn_args;
attn_args.q_ptr = q_dev.GetDeviceBuffer();
attn_args.k_ptr = k_dev.GetDeviceBuffer();
attn_args.v_ptr = v_dev.GetDeviceBuffer();
attn_args.q_ptr = q_dev.GetDeviceBuffer();
attn_args.k_ptr = k_dev.GetDeviceBuffer();
attn_args.v_ptr = v_dev.GetDeviceBuffer();
attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer();
attn_args.o_ptr = o_dev.GetDeviceBuffer();
attn_args.seqlen_q = seqlen_q;
attn_args.seqlen_k = seqlen_k;
attn_args.batch = batch;
attn_args.max_seqlen_q = seqlen_q;
attn_args.hdim_q = hdim_q;
attn_args.hdim_v = hdim_v;
attn_args.nhead_q = nhead;
attn_args.nhead_k = nhead_k;
attn_args.scale_s = scale_s;
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
attn_args.batch_stride_q = q_strides[0];
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
attn_args.o_ptr = o_dev.GetDeviceBuffer();
attn_args.seqlen_q = seqlen_q;
attn_args.seqlen_k = seqlen_k;
attn_args.batch = batch;
attn_args.max_seqlen_q = seqlen_q;
attn_args.hdim_q = hdim_q;
attn_args.hdim_v = hdim_v;
attn_args.nhead_q = nhead;
attn_args.nhead_k = nhead_k;
attn_args.scale_s = scale_s;
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
attn_args.batch_stride_q = q_strides[0];
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
}
@@ -326,38 +325,39 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.bm0 = BLKQ;
fmha_vsa_fwd_args attn_args;
attn_args.q_ptr = q_dev.GetDeviceBuffer();
attn_args.k_ptr = k_dev.GetDeviceBuffer();
attn_args.v_ptr = v_dev.GetDeviceBuffer();
attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
attn_args.q_ptr = q_dev.GetDeviceBuffer();
attn_args.k_ptr = k_dev.GetDeviceBuffer();
attn_args.v_ptr = v_dev.GetDeviceBuffer();
attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer();
attn_args.o_ptr = o_dev.GetDeviceBuffer();
attn_args.seqlen_q = seqlen_q;
attn_args.seqlen_k = seqlen_k;
attn_args.batch = batch;
attn_args.max_seqlen_q = seqlen_q;
attn_args.hdim_q = hdim_q;
attn_args.hdim_v = hdim_v;
attn_args.nhead_q = nhead;
attn_args.nhead_k = nhead_k;
attn_args.scale_s = scale_s;
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
attn_args.batch_stride_q = q_strides[0];
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
attn_args.o_ptr = o_dev.GetDeviceBuffer();
attn_args.seqlen_q = seqlen_q;
attn_args.seqlen_k = seqlen_k;
attn_args.batch = batch;
attn_args.max_seqlen_q = seqlen_q;
attn_args.hdim_q = hdim_q;
attn_args.hdim_v = hdim_v;
attn_args.nhead_q = nhead;
attn_args.nhead_k = nhead_k;
attn_args.scale_s = scale_s;
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
attn_args.batch_stride_q = q_strides[0];
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
avg_ms =
sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
}
else
{
@@ -367,8 +367,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
// ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ----
std::size_t flop = static_cast<std::size_t>(batch) * nhead *
(static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
(static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
float tflops = (avg_ms > 0.f) ? static_cast<float>(flop) / 1.E9f / avg_ms : 0.f;
if(avg_ms > 0.f)
@@ -382,14 +382,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
block_map_dev.FromDevice(block_map_host.data());
// ---- count active blocks ----
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
ck_tile::index_t active_blocks = 0;
for(size_t i = 0; i < block_map_host.mData.size(); ++i)
if(block_map_host.mData[i])
active_blocks++;
float actual_sparsity = 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity
<< "(" << active_blocks << "/" << total_blocks << ")" << std::flush;
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity << "(" << active_blocks
<< "/" << total_blocks << ")" << std::flush;
// ---- validation ----
bool pass = true;
@@ -405,8 +406,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
size_t num_errors = 0;
float max_diff = 0.0f;
size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
@@ -423,9 +424,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
}
pass = (num_errors == 0);
std::cout << ", " << (pass ? "PASS" : "FAIL")
<< "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
<< " maxdiff=" << max_diff << ")";
std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/"
<< output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")";
}
std::cout << std::endl;

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel
};
// std::variant<> can't take in a list initializer, overload for backward compatibility
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* block_relation_onehot_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
// 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384).
// Not constexpr because the assert needs runtime evaluation.
CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* block_relation_onehot_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
// 256-bool LDS staging caps N_k <= 256 per Q-tile.
// For kN0=64 this means seqlen_k <= 16384.
assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 &&
"256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)");
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
@@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
// Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile,
// i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs checks this with a
// host-side assert, which is compiled out when NDEBUG is defined.
// The extra 1024B is jenga-specific: the pipeline
// (block_fmha_pipeline_qr_ks_vs_async_jenga.hpp:261) stages block_relation_onehot here.
// Do NOT copy this `+ 256*sizeof(int)` to other sparse kernels (e.g. VSA) without first
// wiring a real reader.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
// if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d",

View File

@@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

View File

@@ -22,7 +22,7 @@ struct SpargeBlockMapKernel
static constexpr index_t kN0 = Pipeline::kN0;
static constexpr index_t D = Pipeline::D;
static constexpr index_t kAlignment = 16 / sizeof(QDataType);
static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width
struct Kargs
{
@@ -52,19 +52,18 @@ struct SpargeBlockMapKernel
void* lut_ptr;
void* valid_block_num_ptr;
// R20 K-stat workspace from Kernel A
const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
// K-block stats workspace produced by SpargeKStatsKernel
const void*
pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
index_t N_k;
// R21A Phase 4: optional per-head topk (size = nhead_q floats).
// nullptr => use scalar `topk` for all heads.
// Per-head topk (size = nhead_q floats). Required (non-null).
const float* topk_per_head;
// R21B: optional per-head cdfthreshd (size = nhead_q floats).
// nullptr => use scalar `cdfthreshd` for all heads.
// Only consulted on topk<=0 path; bench currently always uses topk path.
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
// only consulted on topk<=0 path.
const float* cdfthreshd_per_head;
};
@@ -90,8 +89,8 @@ struct SpargeBlockMapKernel
void* valid_block_num_ptr,
const void* pooled_k_ws_ptr,
const void* sim_k_ws_ptr,
const float* topk_per_head = nullptr,
const float* cdfthreshd_per_head = nullptr)
const float* topk_per_head,
const float* cdfthreshd_per_head)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{q_ptr,
@@ -195,20 +194,15 @@ struct SpargeBlockMapKernel
// Shared memory
__shared__ char smem[Pipeline::GetSmemSize()];
// R20 K-stat workspace: pre-offset for this (b, hk).
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
// K-stat workspace: pre-offset for this (b, hk).
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
const index_t khead_off = (b * nhead_k + hk) * N_k;
const auto* pooled_k_ws =
reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
const auto* sim_k_ws =
reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
reinterpret_cast<const KDataType*>(kargs.pooled_k_ws_ptr) + khead_off * D;
const auto* sim_k_ws = reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
// R21A Phase 4: per-head topk if provided, else scalar broadcast.
const float topk_eff =
(kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
// R21B: per-head cdfthreshd if provided, else scalar broadcast.
const float cdfthreshd_eff =
(kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;
const float topk_eff = kargs.topk_per_head[hq];
const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq];
Pipeline{}(q_window,
k_window,

View File

@@ -40,15 +40,13 @@ struct SpargeKStatsKernel
float simthreshd1;
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
void* sim_k_ptr; // [batch, nhead_k, N_k] uint8
index_t N_k;
// R21A Phase 4 + R21B fix: optional per-head simthreshd1.
// Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
// (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
// nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
// Per-head simthreshd1. Required (non-null). The buffer is sized nhead_q floats to match
// the SpargeAttn upstream contract, but the kernel only reads the entries indexed by [hk].
const float* simthreshd1_per_head;
};
@@ -62,7 +60,7 @@ struct SpargeKStatsKernel
float simthreshd1,
void* pooled_k_ptr,
void* sim_k_ptr,
const float* simthreshd1_per_head = nullptr)
const float* simthreshd1_per_head)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{k_ptr,
@@ -111,17 +109,15 @@ struct SpargeKStatsKernel
{0, 0},
Pipeline::MakeKBlockDistribution());
const index_t N_k = kargs.N_k;
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
const index_t N_k = kargs.N_k;
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
auto* pooled_k_out =
reinterpret_cast<KDataType*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
__shared__ char smem[Pipeline::GetSmemSize()];
// R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
? kargs.simthreshd1_per_head[hk]
: kargs.simthreshd1;
const float simthreshd1_eff = kargs.simthreshd1_per_head[hk];
Pipeline{}(k_window,
kargs.seqlen_k,

View File

@@ -262,7 +262,7 @@ struct SpargeBlockMapPipeline
uint8_t* block_map_ptr,
int32_t* lut_ptr,
int32_t* valid_block_num_ptr,
const float* __restrict__ pooled_k_ws_ptr,
const KDataType* __restrict__ pooled_k_ws_ptr,
const uint8_t* __restrict__ sim_k_ws_ptr,
void* smem_ptr) const
{
@@ -356,10 +356,10 @@ struct SpargeBlockMapPipeline
for(index_t kb = 0; kb < N_k; ++kb)
{
const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
float pooled_k_mean[KPerThread];
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_mean[k] = p_kb[k];
pooled_k_mean[k] = type_convert<float>(p_kb[k]);
float dot = 0.f;
for(index_t k = 0; k < KPerThread; ++k)
@@ -417,8 +417,7 @@ struct SpargeBlockMapPipeline
// cdfthreshd path (topk <= 0) still requires normalised scores so the
// accumulator `cumulative_prob` matches probabilities.
const bool topk_active = (topk > 0.f);
const float inv_sum =
(!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
if(!topk_active)
{
for(index_t i = tid; i < N_k; i += kBlockSize)

View File

@@ -49,8 +49,8 @@ struct SpargeKStatsPipeline
index_t seqlen_k,
index_t kb,
float simthreshd1,
float* __restrict__ pooled_k_out, // D floats
uint8_t* __restrict__ sim_k_out, // 1 byte
KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16)
uint8_t* __restrict__ sim_k_out, // 1 byte
void* smem_ptr) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
@@ -70,19 +70,19 @@ struct SpargeKStatsPipeline
const index_t m_idx = lane_id / KThreads;
// pooled_k_mean: column sum then cross-warp reduce.
// R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
// Drop trailing sync (next cross_warp_reduce has its own leading sync).
float pooled_k_mean[KPerThread];
Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_mean[k] *= inv_bs_k;
// R21A: write pooled_k_mean to global early so its register liveness ends here,
// Write pooled_k_mean to global early so its register liveness ends here,
// freeing VGPR before k_sum_hat becomes live.
if(warp_id == 0 && m_idx == 0)
{
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
pooled_k_out[k_idx * KPerThread + k] = type_convert<KDataType>(pooled_k_mean[k]);
}
// K row L2 norms + normalised column sum (k_sum_hat)
@@ -91,7 +91,7 @@ struct SpargeKStatsPipeline
float k_sum_hat[KPerThread];
Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
// R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
// Drop trailing sync (no further smem read; only intra-warp shuffle + global write).
Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);
// sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1