mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
refactor(sparse_attn): caller-owned workspace + dtype-aware sizing
Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout() host helpers. Split the combined sparge_blockmap_fwd into stage launchers (sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained launch is timed end-to-end. Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve workspace footprint and match dense-FMHA precision. Tighten per-head superparam pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -67,47 +67,24 @@ using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
|
||||
using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_fp16_problem>;
|
||||
using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel<kstats_fp16_pipeline>;
|
||||
|
||||
// ============================================================================
|
||||
// bf16: D=128, kM0=64, kN0=128
|
||||
// ============================================================================
|
||||
// bf16: dtype-independent aliases share fp16 chain; only problem differs.
|
||||
using bmap_bf16_block_tile = bmap_fp16_block_tile;
|
||||
using bmap_bf16_shape = bmap_fp16_shape;
|
||||
using bmap_bf16_trait = bmap_fp16_trait;
|
||||
using bmap_bf16_variant = bmap_fp16_variant;
|
||||
using bmap_bf16_mask = bmap_fp16_mask;
|
||||
|
||||
using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
|
||||
|
||||
using bmap_bf16_shape =
|
||||
ck_tile::TileFmhaShape<bmap_bf16_block_tile,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
ck_tile::sequence<16, 16, 16>,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
ck_tile::sequence<16, 16, 16>,
|
||||
true>;
|
||||
|
||||
using bmap_bf16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
|
||||
true, // kPadSeqLenK
|
||||
true, // kPadHeadDimQ
|
||||
true, // kPadHeadDimV
|
||||
false, // kHasLogitsSoftCap
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false, // kStoreLSE
|
||||
false, // kHasDropout
|
||||
false, // kHasRandVal
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
|
||||
-1,
|
||||
false>;
|
||||
|
||||
using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using bmap_bf16_mask = ck_tile::GenericAttentionMask<false>;
|
||||
|
||||
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
|
||||
ck_tile::bf16_t, // KDataType
|
||||
ck_tile::bf16_t, // VDataType
|
||||
float, // SaccDataType
|
||||
float, // SMPLComputeDataType
|
||||
ck_tile::bf16_t, // BiasDataType
|
||||
uint8_t, // RandValOutputDataType
|
||||
float, // LSEDataType
|
||||
ck_tile::bf16_t, // PDataType
|
||||
float, // OaccDataType
|
||||
ck_tile::bf16_t, // ODataType
|
||||
using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, // QDataType
|
||||
ck_tile::bf16_t, // KDataType
|
||||
ck_tile::bf16_t, // VDataType
|
||||
float, // SaccDataType
|
||||
float, // SMPLComputeDataType
|
||||
ck_tile::bf16_t, // BiasDataType
|
||||
uint8_t, // RandValOutputDataType
|
||||
float, // LSEDataType
|
||||
ck_tile::bf16_t, // PDataType
|
||||
float, // OaccDataType
|
||||
ck_tile::bf16_t, // ODataType
|
||||
bmap_bf16_shape,
|
||||
false, // kIsGroupMode
|
||||
bmap_bf16_variant,
|
||||
@@ -122,168 +99,168 @@ using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_bf16_problem>;
|
||||
using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel<kstats_bf16_pipeline>;
|
||||
|
||||
// ============================================================================
|
||||
// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized
|
||||
// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged.
|
||||
// Workspace layout: caller owns the buffer; we just compute size + offsets.
|
||||
// Layout = [pooled_k (KDataType) | sim_k (uint8)]. sim_k follows pooled_k with
|
||||
// no padding (uint8 has alignment 1).
|
||||
// ============================================================================
|
||||
|
||||
namespace {
|
||||
|
||||
struct KStatsWorkspace
|
||||
constexpr int sparge_kN0_for(int hdim_q)
|
||||
{
|
||||
void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32
|
||||
void* sim_k_dev = nullptr; // [batch, nhead_k, N_k] uint8
|
||||
size_t pooled_k_bytes = 0;
|
||||
size_t sim_k_bytes = 0;
|
||||
|
||||
void ensure(int batch, int nhead_k, int N_k, int D)
|
||||
{
|
||||
const size_t need_p = static_cast<size_t>(batch) * nhead_k * N_k * D * sizeof(float);
|
||||
const size_t need_s = static_cast<size_t>(batch) * nhead_k * N_k * sizeof(uint8_t);
|
||||
if(need_p > pooled_k_bytes)
|
||||
{
|
||||
if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev);
|
||||
(void)hipMalloc(&pooled_k_dev, need_p);
|
||||
pooled_k_bytes = need_p;
|
||||
}
|
||||
if(need_s > sim_k_bytes)
|
||||
{
|
||||
if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev);
|
||||
(void)hipMalloc(&sim_k_dev, need_s);
|
||||
sim_k_bytes = need_s;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
KStatsWorkspace& g_kstats_ws()
|
||||
{
|
||||
static KStatsWorkspace ws;
|
||||
return ws;
|
||||
// d=128 instances use kN0=128 (see bmap_fp16_block_tile).
|
||||
return (hdim_q == 128) ? 128 : 0;
|
||||
}
|
||||
|
||||
template <typename KStatsKernel, typename BlockMapKernel>
|
||||
void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s)
|
||||
size_t dtype_bytes(const std::string& dt)
|
||||
{
|
||||
const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0);
|
||||
const int D = BlockMapKernel::D;
|
||||
auto& ws = g_kstats_ws();
|
||||
ws.ensure(args.batch, args.nhead_k, N_k, D);
|
||||
if(dt == "fp16" || dt == "bf16")
|
||||
return 2;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Stage 1: K stats
|
||||
{
|
||||
auto [kargs, grids] =
|
||||
sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, ws.pooled_k_dev, ws.sim_k_dev);
|
||||
const dim3 blocks = KStatsKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
// Stage 2: block_map (reads ws)
|
||||
{
|
||||
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(
|
||||
args, ws.pooled_k_dev, ws.sim_k_dev);
|
||||
const dim3 blocks = BlockMapKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
sparge_blockmap_workspace_layout
|
||||
sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args)
|
||||
{
|
||||
const int kN0 = sparge_kN0_for(traits.hdim_q);
|
||||
const int N_k = (kN0 > 0) ? ck_tile::integer_divide_ceil(args.seqlen_k, kN0) : 0;
|
||||
const int D = traits.hdim_q;
|
||||
const size_t element_bytes = dtype_bytes(traits.data_type);
|
||||
|
||||
sparge_blockmap_workspace_layout layout{};
|
||||
layout.pooled_k_offset = 0;
|
||||
layout.pooled_k_bytes =
|
||||
static_cast<size_t>(args.batch) * args.nhead_k * N_k * D * element_bytes;
|
||||
layout.sim_k_offset = layout.pooled_k_bytes;
|
||||
layout.sim_k_bytes = static_cast<size_t>(args.batch) * args.nhead_k * N_k * sizeof(uint8_t);
|
||||
layout.total_bytes = layout.sim_k_offset + layout.sim_k_bytes;
|
||||
return layout;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Stage launchers: read args.workspace_ptr split per layout, run one kernel.
|
||||
// ============================================================================
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename KStatsKernel>
|
||||
void launch_kstats_only(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
const auto layout = sparge_blockmap_compute_workspace_layout(traits, args);
|
||||
auto* ws_base = static_cast<char*>(args.workspace_ptr);
|
||||
void* pooled_k_ptr = ws_base + layout.pooled_k_offset;
|
||||
void* sim_k_ptr = ws_base + layout.sim_k_offset;
|
||||
|
||||
auto [kargs, grids] =
|
||||
sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, pooled_k_ptr, sim_k_ptr);
|
||||
const dim3 blocks = KStatsKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(s);
|
||||
}
|
||||
|
||||
template <typename BlockMapKernel>
|
||||
void launch_blockmap_only(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
const auto layout = sparge_blockmap_compute_workspace_layout(traits, args);
|
||||
auto* ws_base = static_cast<char*>(args.workspace_ptr);
|
||||
void* pooled_k_ptr = ws_base + layout.pooled_k_offset;
|
||||
void* sim_k_ptr = ws_base + layout.sim_k_offset;
|
||||
|
||||
auto [kargs, grids] =
|
||||
sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(args, pooled_k_ptr, sim_k_ptr);
|
||||
const dim3 blocks = BlockMapKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(s);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ============================================================================
|
||||
// Dispatch
|
||||
// Oneshot stages (no timing): caller chains them via launch_kernel.
|
||||
// ============================================================================
|
||||
|
||||
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(traits.data_type == "fp16" && traits.hdim_q == 128)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
|
||||
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
|
||||
launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s_);
|
||||
});
|
||||
}
|
||||
|
||||
if(traits.data_type == "bf16" && traits.hdim_q == 128)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
|
||||
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
|
||||
launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s_);
|
||||
});
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
|
||||
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
return -1.f;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Oneshot version: launches kernel without timing wrapper
|
||||
// ============================================================================
|
||||
|
||||
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(traits.data_type == "fp16" && traits.hdim_q == 128)
|
||||
{
|
||||
launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s);
|
||||
launch_kstats_only<kstats_fp16_kernel>(traits, args, s);
|
||||
return;
|
||||
}
|
||||
|
||||
if(traits.data_type == "bf16" && traits.hdim_q == 128)
|
||||
{
|
||||
launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s);
|
||||
launch_kstats_only<kstats_bf16_kernel>(traits, args, s);
|
||||
return;
|
||||
}
|
||||
|
||||
std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type
|
||||
std::cerr << "sparge_kstats_fwd_oneshot: unsupported config (data_type=" << traits.data_type
|
||||
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
}
|
||||
|
||||
void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(traits.data_type == "fp16" && traits.hdim_q == 128)
|
||||
{
|
||||
launch_blockmap_only<bmap_fp16_kernel>(traits, args, s);
|
||||
return;
|
||||
}
|
||||
if(traits.data_type == "bf16" && traits.hdim_q == 128)
|
||||
{
|
||||
launch_blockmap_only<bmap_bf16_kernel>(traits, args, s);
|
||||
return;
|
||||
}
|
||||
std::cerr << "sparge_blockmap_only_fwd_oneshot: unsupported config (data_type="
|
||||
<< traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Combined functions: blockmap + attention timed together via launch_kernel
|
||||
// Combined functions: kstats + blockmap + attention timed together.
|
||||
// ============================================================================
|
||||
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
|
||||
fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a,
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits bmap_t,
|
||||
sparge_blockmap_args bmap_a,
|
||||
fmha_jenga_fwd_traits attn_t,
|
||||
fmha_jenga_fwd_args attn_a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
|
||||
<< std::flush;
|
||||
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_jenga_fwd_oneshot(attn_t, attn_a, s_);
|
||||
});
|
||||
[=](const ck_tile::stream_config& s_) { fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); });
|
||||
}
|
||||
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a,
|
||||
fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a,
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t,
|
||||
sparge_blockmap_args bmap_a,
|
||||
fmha_vsa_fwd_traits attn_t,
|
||||
fmha_vsa_fwd_args attn_a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q
|
||||
<< std::flush;
|
||||
std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q
|
||||
<< ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); },
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_vsa_fwd_oneshot(attn_t, attn_a, s_);
|
||||
});
|
||||
[=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); });
|
||||
}
|
||||
|
||||
@@ -48,14 +48,23 @@ struct sparge_blockmap_args
|
||||
void* lut_ptr;
|
||||
void* valid_block_num_ptr;
|
||||
|
||||
// R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar.
|
||||
// Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized
|
||||
// by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the
|
||||
// first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for
|
||||
// MQA/GQA aligns to upstream tuned ckpt layout.
|
||||
const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1])
|
||||
const float* cdfthreshd_per_head_ptr = nullptr; // size = nhead_q floats
|
||||
const float* topk_per_head_ptr = nullptr; // size = nhead_q floats
|
||||
// Caller-owned K-stats workspace; size from sparge_blockmap_get_workspace_size.
|
||||
// Internal layout (pooled_k then sim_k) given by sparge_blockmap_workspace_layout.
|
||||
void* workspace_ptr = nullptr;
|
||||
|
||||
// size = nhead_q to match SpargeAttn upstream hyperparameter_check
|
||||
const float* simthreshd1_per_head_ptr = nullptr;
|
||||
const float* cdfthreshd_per_head_ptr = nullptr;
|
||||
const float* topk_per_head_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct sparge_blockmap_workspace_layout
|
||||
{
|
||||
size_t pooled_k_offset; // bytes from workspace_ptr
|
||||
size_t pooled_k_bytes;
|
||||
size_t sim_k_offset; // bytes from workspace_ptr
|
||||
size_t sim_k_bytes;
|
||||
size_t total_bytes;
|
||||
};
|
||||
|
||||
struct sparge_blockmap_traits
|
||||
@@ -127,19 +136,36 @@ auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args,
|
||||
// ============================================================================
|
||||
// Hand-written template instantiation dispatch
|
||||
// ============================================================================
|
||||
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
// Workspace sizing helpers (host, no template instantiation needed).
|
||||
sparge_blockmap_workspace_layout
|
||||
sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args);
|
||||
|
||||
// Combined functions: blockmap + attention with unified timing
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args,
|
||||
fmha_jenga_fwd_traits, fmha_jenga_fwd_args,
|
||||
inline size_t sparge_blockmap_get_workspace_size(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args)
|
||||
{
|
||||
return sparge_blockmap_compute_workspace_layout(traits, args).total_bytes;
|
||||
}
|
||||
|
||||
// Stage 1: K-stats only. Writes pooled_k + sim_k into args.workspace_ptr.
|
||||
void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
// Stage 2: block_map only. Reads pooled_k + sim_k from args.workspace_ptr.
|
||||
void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
// Combined functions: kstats + blockmap + attention with unified timing.
|
||||
float sparge_jenga_fwd(sparge_blockmap_traits,
|
||||
sparge_blockmap_args,
|
||||
fmha_jenga_fwd_traits,
|
||||
fmha_jenga_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args,
|
||||
fmha_vsa_fwd_traits, fmha_vsa_fwd_args,
|
||||
float sparge_vsa_fwd_combined(sparge_blockmap_traits,
|
||||
sparge_blockmap_args,
|
||||
fmha_vsa_fwd_traits,
|
||||
fmha_vsa_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -25,8 +25,11 @@
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T>
|
||||
make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm)
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
@@ -86,8 +89,7 @@ float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("pipeline", "jenga", "attention pipeline: jenga / vsa")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
@@ -105,10 +107,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name")
|
||||
.insert("perhead", "0",
|
||||
"R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
|
||||
"(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -120,29 +119,31 @@ auto create_args(int argc, char* argv[])
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
std::string pipeline = arg_parser.get_str("pipeline");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
float topk = arg_parser.get_float("topk");
|
||||
float cdfthreshd = arg_parser.get_float("cdfthreshd");
|
||||
float simthreshd1 = arg_parser.get_float("simthreshd1");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int perhead = arg_parser.get_int("perhead");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
std::string pipeline = arg_parser.get_str("pipeline");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
float topk = arg_parser.get_float("topk");
|
||||
float cdfthreshd = arg_parser.get_float("cdfthreshd");
|
||||
float simthreshd1 = arg_parser.get_float("simthreshd1");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
if(nhead_k < 0) nhead_k = nhead;
|
||||
if(seqlen_k < 0) seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0) hdim_v = hdim_q;
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
// If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
|
||||
if(cdfthreshd >= 0.0f)
|
||||
@@ -162,15 +163,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::string prec_str = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
std::cout << "[" << pipeline << "|" << prec_str
|
||||
<< "] b=" << batch << " h=" << nhead << " s=" << seqlen_q
|
||||
<< " d=" << hdim_q << " topk=" << topk
|
||||
<< " sim1=" << simthreshd1 << std::flush;
|
||||
std::cout << "[" << pipeline << "|" << prec_str << "] b=" << batch << " h=" << nhead
|
||||
<< " s=" << seqlen_q << " d=" << hdim_q << " topk=" << topk << " sim1=" << simthreshd1
|
||||
<< std::flush;
|
||||
|
||||
// ---- allocate host tensors ----
|
||||
auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
auto q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
auto k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
auto v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
auto output_host = o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
|
||||
@@ -213,62 +213,61 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
bmap_traits.hdim_q = hdim_q;
|
||||
|
||||
sparge_blockmap_args bmap_args;
|
||||
bmap_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
bmap_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
bmap_args.batch = batch;
|
||||
bmap_args.seqlen_q = seqlen_q;
|
||||
bmap_args.seqlen_k = seqlen_k;
|
||||
bmap_args.hdim_q = hdim_q;
|
||||
bmap_args.nhead_q = nhead;
|
||||
bmap_args.nhead_k = nhead_k;
|
||||
bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
bmap_args.batch_stride_q = q_strides[0];
|
||||
bmap_args.batch_stride_k = k_strides[0];
|
||||
bmap_args.simthreshd1 = simthreshd1;
|
||||
bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
|
||||
bmap_args.topk = topk;
|
||||
bmap_args.scale = scale_s;
|
||||
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
|
||||
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
bmap_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
bmap_args.batch = batch;
|
||||
bmap_args.seqlen_q = seqlen_q;
|
||||
bmap_args.seqlen_k = seqlen_k;
|
||||
bmap_args.hdim_q = hdim_q;
|
||||
bmap_args.nhead_q = nhead;
|
||||
bmap_args.nhead_k = nhead_k;
|
||||
bmap_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
bmap_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
bmap_args.batch_stride_q = q_strides[0];
|
||||
bmap_args.batch_stride_k = k_strides[0];
|
||||
bmap_args.simthreshd1 = simthreshd1;
|
||||
bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f;
|
||||
bmap_args.topk = topk;
|
||||
bmap_args.scale = scale_s;
|
||||
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
|
||||
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
// R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
|
||||
// to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
|
||||
// K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing.
|
||||
const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args);
|
||||
ck_tile::DeviceMem kstats_ws_dev(ws_bytes);
|
||||
bmap_args.workspace_ptr = kstats_ws_dev.GetDeviceBuffer();
|
||||
|
||||
// Per-head superparam buffers, all sized [nhead_q] to match SpargeAttn upstream contract.
|
||||
// K-side kernel reads only the first nhead_k entries via [hk].
|
||||
// Filled with scalar broadcast; per-head index correctness verified by separate unit test.
|
||||
ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
|
||||
ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
|
||||
ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
|
||||
if(perhead != 0)
|
||||
ck_tile::DeviceMem cdf_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
|
||||
{
|
||||
std::vector<float> topk_h(nhead);
|
||||
std::vector<float> sim1_h(nhead);
|
||||
std::vector<float> cdf_h (nhead);
|
||||
for(int h = 0; h < nhead; ++h)
|
||||
{
|
||||
// small per-head jitter around scalar topk so sparsity differs by head
|
||||
const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
|
||||
topk_h[h] = topk * (1.0f + jitter);
|
||||
sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
|
||||
cdf_h[h] = cdfthreshd;
|
||||
}
|
||||
std::vector<float> topk_h(nhead, topk);
|
||||
std::vector<float> sim1_h(nhead, simthreshd1);
|
||||
std::vector<float> cdf_h(nhead, cdfthreshd);
|
||||
topk_per_head_dev.ToDevice(topk_h.data());
|
||||
sim1_per_head_dev.ToDevice(sim1_h.data());
|
||||
cdf_per_head_dev .ToDevice(cdf_h.data());
|
||||
bmap_args.topk_per_head_ptr = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.cdfthreshd_per_head_ptr = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
|
||||
cdf_per_head_dev.ToDevice(cdf_h.data());
|
||||
bmap_args.topk_per_head_ptr =
|
||||
static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.simthreshd1_per_head_ptr =
|
||||
static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
|
||||
bmap_args.cdfthreshd_per_head_ptr =
|
||||
static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// ---- build attention args ----
|
||||
ck_tile::stream_config stream_cfg;
|
||||
stream_cfg.stream_id_ = nullptr;
|
||||
stream_cfg.stream_id_ = nullptr;
|
||||
stream_cfg.time_kernel_ = true;
|
||||
stream_cfg.log_level_ = kname;
|
||||
stream_cfg.log_level_ = kname;
|
||||
stream_cfg.cold_niters_ = warmup;
|
||||
stream_cfg.nrepeat_ = repeat;
|
||||
stream_cfg.nrepeat_ = repeat;
|
||||
|
||||
float avg_ms = -1.0f;
|
||||
|
||||
@@ -283,35 +282,35 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_jenga_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer();
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
|
||||
avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
@@ -326,38 +325,39 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_vsa_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
attn_args.k_ptr = k_dev.GetDeviceBuffer();
|
||||
attn_args.v_ptr = v_dev.GetDeviceBuffer();
|
||||
attn_args.lut_ptr = lut_dev.GetDeviceBuffer();
|
||||
attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer();
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
attn_args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
attn_args.seqlen_q = seqlen_q;
|
||||
attn_args.seqlen_k = seqlen_k;
|
||||
attn_args.batch = batch;
|
||||
attn_args.max_seqlen_q = seqlen_q;
|
||||
attn_args.hdim_q = hdim_q;
|
||||
attn_args.hdim_v = hdim_v;
|
||||
attn_args.nhead_q = nhead;
|
||||
attn_args.nhead_k = nhead_k;
|
||||
attn_args.scale_s = scale_s;
|
||||
attn_args.stride_q = q_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_k = k_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_v = v_strides[i_perm ? 2 : 1];
|
||||
attn_args.stride_o = o_strides[o_perm ? 2 : 1];
|
||||
attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2];
|
||||
attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2];
|
||||
attn_args.batch_stride_q = q_strides[0];
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
|
||||
avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
avg_ms =
|
||||
sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -367,8 +367,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ----
|
||||
std::size_t flop = static_cast<std::size_t>(batch) * nhead *
|
||||
(static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
|
||||
(static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * seqlen_q * seqlen_k * hdim_v);
|
||||
float tflops = (avg_ms > 0.f) ? static_cast<float>(flop) / 1.E9f / avg_ms : 0.f;
|
||||
|
||||
if(avg_ms > 0.f)
|
||||
@@ -382,14 +382,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
block_map_dev.FromDevice(block_map_host.data());
|
||||
|
||||
// ---- count active blocks ----
|
||||
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
|
||||
ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
for(size_t i = 0; i < block_map_host.mData.size(); ++i)
|
||||
if(block_map_host.mData[i])
|
||||
active_blocks++;
|
||||
float actual_sparsity = 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity
|
||||
<< "(" << active_blocks << "/" << total_blocks << ")" << std::flush;
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity << "(" << active_blocks
|
||||
<< "/" << total_blocks << ")" << std::flush;
|
||||
|
||||
// ---- validation ----
|
||||
bool pass = true;
|
||||
@@ -405,8 +406,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
float max_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
@@ -423,9 +424,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
pass = (num_errors == 0);
|
||||
std::cout << ", " << (pass ? "PASS" : "FAIL")
|
||||
<< "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
|
||||
<< " maxdiff=" << max_diff << ")";
|
||||
std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/"
|
||||
<< output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel
|
||||
};
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* block_relation_onehot_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
// 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384).
|
||||
// Not constexpr because the assert needs runtime evaluation.
|
||||
CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* block_relation_onehot_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
// 256-bool LDS staging caps N_k <= 256 per Q-tile.
|
||||
// For kN0=64 this means seqlen_k <= 16384.
|
||||
assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 &&
|
||||
"256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)");
|
||||
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
@@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
|
||||
// Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile,
|
||||
// i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs asserts this.
|
||||
// The extra 1024B is jenga-specific: pipeline (block_fmha_pipeline_qr_ks_vs_async_jenga
|
||||
// .hpp:261) stages block_relation_onehot here. Do NOT copy this `+ 256*sizeof(int)` to
|
||||
// other sparse kernels (e.g. VSA) without first wiring a real reader.
|
||||
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
|
||||
|
||||
// if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d",
|
||||
|
||||
@@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
|
||||
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
@@ -22,7 +22,7 @@ struct SpargeBlockMapKernel
|
||||
static constexpr index_t kN0 = Pipeline::kN0;
|
||||
static constexpr index_t D = Pipeline::D;
|
||||
|
||||
static constexpr index_t kAlignment = 16 / sizeof(QDataType);
|
||||
static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
@@ -52,19 +52,18 @@ struct SpargeBlockMapKernel
|
||||
void* lut_ptr;
|
||||
void* valid_block_num_ptr;
|
||||
|
||||
// R20 K-stat workspace from Kernel A
|
||||
const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
|
||||
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
|
||||
// K-block stats workspace produced by SpargeKStatsKernel
|
||||
const void*
|
||||
pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
|
||||
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
|
||||
|
||||
index_t N_k;
|
||||
|
||||
// R21A Phase 4: optional per-head topk (size = nhead_q floats).
|
||||
// nullptr => use scalar `topk` for all heads.
|
||||
// Per-head topk (size = nhead_q floats). Required (non-null).
|
||||
const float* topk_per_head;
|
||||
|
||||
// R21B: optional per-head cdfthreshd (size = nhead_q floats).
|
||||
// nullptr => use scalar `cdfthreshd` for all heads.
|
||||
// Only consulted on topk<=0 path; bench currently always uses topk path.
|
||||
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
|
||||
// only consulted on topk<=0 path.
|
||||
const float* cdfthreshd_per_head;
|
||||
};
|
||||
|
||||
@@ -90,8 +89,8 @@ struct SpargeBlockMapKernel
|
||||
void* valid_block_num_ptr,
|
||||
const void* pooled_k_ws_ptr,
|
||||
const void* sim_k_ws_ptr,
|
||||
const float* topk_per_head = nullptr,
|
||||
const float* cdfthreshd_per_head = nullptr)
|
||||
const float* topk_per_head,
|
||||
const float* cdfthreshd_per_head)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{q_ptr,
|
||||
@@ -195,20 +194,15 @@ struct SpargeBlockMapKernel
|
||||
// Shared memory
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
|
||||
// R20 K-stat workspace: pre-offset for this (b, hk).
|
||||
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
|
||||
// K-stat workspace: pre-offset for this (b, hk).
|
||||
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
|
||||
const index_t khead_off = (b * nhead_k + hk) * N_k;
|
||||
const auto* pooled_k_ws =
|
||||
reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
|
||||
const auto* sim_k_ws =
|
||||
reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
|
||||
reinterpret_cast<const KDataType*>(kargs.pooled_k_ws_ptr) + khead_off * D;
|
||||
const auto* sim_k_ws = reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
|
||||
|
||||
// R21A Phase 4: per-head topk if provided, else scalar broadcast.
|
||||
const float topk_eff =
|
||||
(kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
|
||||
// R21B: per-head cdfthreshd if provided, else scalar broadcast.
|
||||
const float cdfthreshd_eff =
|
||||
(kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;
|
||||
const float topk_eff = kargs.topk_per_head[hq];
|
||||
const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq];
|
||||
|
||||
Pipeline{}(q_window,
|
||||
k_window,
|
||||
|
||||
@@ -40,15 +40,13 @@ struct SpargeKStatsKernel
|
||||
|
||||
float simthreshd1;
|
||||
|
||||
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
|
||||
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
|
||||
void* sim_k_ptr; // [batch, nhead_k, N_k] uint8
|
||||
|
||||
index_t N_k;
|
||||
|
||||
// R21A Phase 4 + R21B fix: optional per-head simthreshd1.
|
||||
// Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
|
||||
// (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
|
||||
// nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
|
||||
// Per-head simthreshd1 pointer (size = nhead_q floats; kernel indexes [hk] only).
|
||||
// Required (non-null); matches SpargeAttn upstream contract.
|
||||
const float* simthreshd1_per_head;
|
||||
};
|
||||
|
||||
@@ -62,7 +60,7 @@ struct SpargeKStatsKernel
|
||||
float simthreshd1,
|
||||
void* pooled_k_ptr,
|
||||
void* sim_k_ptr,
|
||||
const float* simthreshd1_per_head = nullptr)
|
||||
const float* simthreshd1_per_head)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{k_ptr,
|
||||
@@ -111,17 +109,15 @@ struct SpargeKStatsKernel
|
||||
{0, 0},
|
||||
Pipeline::MakeKBlockDistribution());
|
||||
|
||||
const index_t N_k = kargs.N_k;
|
||||
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
|
||||
auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
|
||||
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
|
||||
const index_t N_k = kargs.N_k;
|
||||
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
|
||||
auto* pooled_k_out =
|
||||
reinterpret_cast<KDataType*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
|
||||
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
|
||||
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
|
||||
// R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
|
||||
const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
|
||||
? kargs.simthreshd1_per_head[hk]
|
||||
: kargs.simthreshd1;
|
||||
const float simthreshd1_eff = kargs.simthreshd1_per_head[hk];
|
||||
|
||||
Pipeline{}(k_window,
|
||||
kargs.seqlen_k,
|
||||
|
||||
@@ -262,7 +262,7 @@ struct SpargeBlockMapPipeline
|
||||
uint8_t* block_map_ptr,
|
||||
int32_t* lut_ptr,
|
||||
int32_t* valid_block_num_ptr,
|
||||
const float* __restrict__ pooled_k_ws_ptr,
|
||||
const KDataType* __restrict__ pooled_k_ws_ptr,
|
||||
const uint8_t* __restrict__ sim_k_ws_ptr,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -356,10 +356,10 @@ struct SpargeBlockMapPipeline
|
||||
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
{
|
||||
const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
float pooled_k_mean[KPerThread];
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_mean[k] = p_kb[k];
|
||||
pooled_k_mean[k] = type_convert<float>(p_kb[k]);
|
||||
|
||||
float dot = 0.f;
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
@@ -417,8 +417,7 @@ struct SpargeBlockMapPipeline
|
||||
// cdfthreshd path (topk <= 0) still requires normalised scores so the
|
||||
// accumulator `cumulative_prob` matches probabilities.
|
||||
const bool topk_active = (topk > 0.f);
|
||||
const float inv_sum =
|
||||
(!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
|
||||
const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
|
||||
if(!topk_active)
|
||||
{
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
|
||||
@@ -49,8 +49,8 @@ struct SpargeKStatsPipeline
|
||||
index_t seqlen_k,
|
||||
index_t kb,
|
||||
float simthreshd1,
|
||||
float* __restrict__ pooled_k_out, // D floats
|
||||
uint8_t* __restrict__ sim_k_out, // 1 byte
|
||||
KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16)
|
||||
uint8_t* __restrict__ sim_k_out, // 1 byte
|
||||
void* smem_ptr) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
@@ -70,19 +70,19 @@ struct SpargeKStatsPipeline
|
||||
const index_t m_idx = lane_id / KThreads;
|
||||
|
||||
// pooled_k_mean: column sum then cross-warp reduce.
|
||||
// R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
|
||||
// Drop trailing sync (next cross_warp_reduce has its own leading sync).
|
||||
float pooled_k_mean[KPerThread];
|
||||
Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
|
||||
Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_mean[k] *= inv_bs_k;
|
||||
|
||||
// R21A: write pooled_k_mean to global early so its register liveness ends here,
|
||||
// Write pooled_k_mean to global early so its register liveness ends here,
|
||||
// freeing VGPR before k_sum_hat becomes live.
|
||||
if(warp_id == 0 && m_idx == 0)
|
||||
{
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
|
||||
pooled_k_out[k_idx * KPerThread + k] = type_convert<KDataType>(pooled_k_mean[k]);
|
||||
}
|
||||
|
||||
// K row L2 norms + normalised column sum (k_sum_hat)
|
||||
@@ -91,7 +91,7 @@ struct SpargeKStatsPipeline
|
||||
|
||||
float k_sum_hat[KPerThread];
|
||||
Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
|
||||
// R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
|
||||
// Drop trailing sync (no further smem read; only intra-warp shuffle + global write).
|
||||
Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);
|
||||
|
||||
// sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1
|
||||
|
||||
Reference in New Issue
Block a user