sparse_attn: split-launch dispatch + 3-mode PV-skip

- Per-head pv_threshold via head_remap LUT (CLI: -pv_threshold_per_head);
  sentinel 1e30 routes to kEnablePVSkip=false bucket
- kEnablePVSkip bool → PVSkipMode enum {kNone, kPerWave, kPerBlock};
  new kPerBlock matches upstream sm80 (LDS vote, V loads unconditional).
  CLI: -pv_mode={none,warp,block}, default warp
- README: PV-skip modes section + MI300X 3-curve sparsity chart

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-05-19 21:45:23 -04:00
parent 304c1f9244
commit d939c3b4fc
8 changed files with 585 additions and 95 deletions

View File

@@ -7,6 +7,9 @@
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
// PVSkipMode enum lives in the sparge pipeline header; pull it in so the
// kernel template arg can name it (R30: promote bool kEnablePVSkip_ to 3-way enum).
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp"
#include <string>
#include <type_traits>
@@ -21,7 +24,9 @@
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_, bool kEnablePVSkip_ = true>
template <typename FmhaPipeline_,
typename EpiloguePipeline_,
PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave>
struct FmhaFwdSpargeKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
@@ -30,7 +35,9 @@ struct FmhaFwdSpargeKernel
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_;
// Legacy alias preserved: any non-kNone mode is "PV-skip enabled".
static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone);
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
@@ -99,6 +106,15 @@ struct FmhaFwdSpargeKernel
ck_tile::index_t nhead_ratio_qk;
float scale_s;
float pv_threshold;
// R26 split-launch: when non-null, indexed by remapped i_nhead (post head_remap),
// overrides scalar pv_threshold. Buffer length = num_head_q.
const float* pv_threshold_per_head;
// R26 split-launch: when non-null, i_nhead = head_remap_ptr[blockIdx.y].
// Buffer length = nhead_in_launch. Null = identity (blockIdx.y directly).
const int* head_remap_ptr;
// R26 split-launch: gridDim.y when head_remap_ptr is active (== bucket size).
// Kept for future host-side asserts / debug; kernel reads via blockIdx.y.
ck_tile::index_t nhead_in_launch;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
@@ -165,7 +181,12 @@ struct FmhaFwdSpargeKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
// R26 split-launch (default-null preserves
// backward compat = scalar mode).
const float* pv_threshold_per_head = nullptr,
const int* head_remap_ptr = nullptr,
ck_tile::index_t nhead_in_launch = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -185,6 +206,9 @@ struct FmhaFwdSpargeKernel
scale_s,
#endif
pv_threshold,
pv_threshold_per_head,
head_remap_ptr,
nhead_in_launch,
stride_q,
stride_k,
stride_v,
@@ -224,7 +248,18 @@ struct FmhaFwdSpargeKernel
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
// R26 split-launch: if head_remap_ptr is set, translate the launch-local
// head index to the original num_head_q-space index. Null pointer ->
// identity (single-launch backward compat). The remap LUT load is uniform
// across the wavefront (same blockIdx.y for all lanes), but the compiler
// can't infer scalar-uniformity through a global ptr indirection, so we
// broadcast via readfirstlane. Without this, dependent offset/buffer-
// descriptor computations spill to VGPRs and buffer_load_dwordx4 inline
// asm rejects the VGPR operand.
const index_t i_nhead =
(kargs.head_remap_ptr != nullptr)
? __builtin_amdgcn_readfirstlane(kargs.head_remap_ptr[blockIdx.y])
: static_cast<index_t>(blockIdx.y);
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
@@ -402,6 +437,23 @@ struct FmhaFwdSpargeKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
// R26 split-launch: per-head pv_threshold override (null = scalar mode).
// i_nhead is already scalar-broadcast in GetTileIndex; the load is uniform
// and the resulting float lands in SGPRs naturally. We additionally route
// via readfirstlane on the int representation as a defensive hint to keep
// it scalar even when the compiler is conservative about float traffic.
float pv_threshold_resolved;
if(kargs.pv_threshold_per_head != nullptr)
{
const int raw = __builtin_amdgcn_readfirstlane(
__builtin_bit_cast(int, kargs.pv_threshold_per_head[i_nhead]));
pv_threshold_resolved = __builtin_bit_cast(float, raw);
}
else
{
pv_threshold_resolved = kargs.pv_threshold;
}
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
@@ -409,7 +461,7 @@ struct FmhaFwdSpargeKernel
valid_block_num_value,
mask,
kargs.scale_s,
kargs.pv_threshold,
pv_threshold_resolved,
variant,
variant_params,
block_indices,

View File

@@ -11,18 +11,40 @@
namespace ck_tile {
// R30: selects how the PV-skip decision is made. Replaces the former
// `bool kEnablePVSkip_` template argument: legacy `true` maps to kPerWave and
// legacy `false` maps to kNone (mapping preserved via codegen). The per-block
// mode follows upstream SpargeAttn semantics; see the R29 researcher report
// per_block_vload_guard.md.
enum class PVSkipMode : int
{
    // Skip path disabled entirely.
    kNone = 0,
    // Per-wavefront vote (the path shipped in R25 A1). Value 1.
    kPerWave,
    // Per-block consensus vote added in R30. Value 2.
    kPerBlock,
};
// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA;
// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original
// _vsa.hpp can remain frozen as an A/B baseline.
//
// R30: kPVSkipMode_ promoted from bool to 3-value enum {kNone, kPerWave, kPerBlock}.
// kPerWave is the R25 A1 shipped path; kPerBlock adds a block-wide consensus AND vote
// (1 LDS slot + 2 block_sync_lds: one after the slot init, one after the votes) so all
// waves in a block agree before skipping the PV mma. Per R29 audit, the V load /
// V->LDS store / cp_async pipeline stay unconditional in BOTH per-wave and per-block
// modes. Only per-block mode gates gemm_1 itself; per-wave mode zeroes the skipping
// wave's rows of P and still runs gemm_1 unconditionally.
//
// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs;
// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim.
template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
bool kEnablePVSkip_ = true>
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave>
struct BlockFmhaPipelineQRKSVSAsyncSparge
{
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_;
// Legacy alias: true iff any PV-skip mode (per-wave or per-block) is active.
// Kept so existing `if constexpr (kEnablePVSkip)` reads still compile.
static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone);
static constexpr bool kPerBlockPVSkip = (kPVSkipMode_ == PVSkipMode::kPerBlock);
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
@@ -140,7 +162,22 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
static constexpr const char* name = "qr_async";
// R30: per-block PV-skip needs one 32-bit LDS slot to broadcast the AND-vote
// result across waves. It is reserved at the TAIL of the pipeline's LDS budget,
// after everything the policy allocates. When the mode is kNone or kPerWave the
// slot is unused; its cost is negligible next to the K/V tiles, so it is always
// reserved to keep the smem layout uniform across modes (simpler than per-mode
// policy plumbing).
static constexpr ck_tile::index_t kPerBlockVoteSlotBytes = 4;

// Total LDS bytes this pipeline requests: the policy footprint (rounded up so
// the vote slot is aligned) plus the vote slot itself. Derived from
// GetPerBlockVoteSlotOffset() so the size and the offset can never disagree.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    return GetPerBlockVoteSlotOffset() + kPerBlockVoteSlotBytes;
}

// R30: byte offset of the per-block vote flag from `smem_ptr`. The slot sits
// just past the policy's smem footprint, rounded UP to a multiple of
// kPerBlockVoteSlotBytes: the slot is accessed through a uint32_t* (including
// an atomicAnd), so the offset must not leave it misaligned. When the policy
// size is already a multiple of 4 the result equals the policy size.
// NOTE(review): this only aligns the offset; it assumes the smem_ptr base
// itself is at least 4-byte aligned — confirm at the allocation site.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPerBlockVoteSlotOffset()
{
    const ck_tile::index_t policy_bytes = Policy::template GetSmemSize<Problem>();
    return (policy_bytes + kPerBlockVoteSlotBytes - 1) / kPerBlockVoteSlotBytes *
           kPerBlockVoteSlotBytes;
}
@@ -513,6 +550,69 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
};
const bool warp_skip = compute_warp_skip();
// ================================================================
// R30: per-block PV-skip — block-wide AND vote over warp_skip.
// Hand-rolled (no `block_and` primitive in CK-tile, no
// `__syncthreads_and` analog — see R30 idiom catalog §7.5).
//
// Protocol:
// 1. Lane 0 of wave 0 initialises a shared LDS sentinel to 1, then
// block_sync_lds() makes the init visible to every wave before
// any vote lands.
// 2. Lane 0 of each wave atomicAnd's its warp_skip (0/1) into the
// sentinel.
// 3. block_sync_lds() — all votes visible, all waves rendezvous
// (uses the same s_waitcnt+s_barrier discipline as the K/V
// LDS chain; lgkmcnt accounting stays consistent — idiom
// §3.1 / §4.2).
// 4. All lanes read the sentinel back into a register. Every lane
// in the block reads the same value, so the result is uniform
// across the block (no explicit readfirstlane is applied here);
// it gates gemm_1 at :607 / :665 below.
//
// Cost: 1 LDS init + 1 atomicAnd + 2 block_sync_lds + 1 LDS load.
// The vote slot lives at `smem_ptr + GetPerBlockVoteSlotOffset()`,
// 4 bytes past the policy K+V budget (see GetSmemSize override).
// No interaction with LdsSeq rotation slots.
//
// V load / V->LDS store / cp_async pipeline stay UNCONDITIONAL in
// both per-wave and per-block modes — matches upstream SpargeAttn
// (R29 audit) and CK-tile LDS-rotation discipline.
// ================================================================
bool block_skip = false;
if constexpr(kPerBlockPVSkip)
{
// Carve a 4-byte uint32 slot at the LDS tail. The cast is safe:
// GetSmemSize() bumped the smem_ptr allocation by 4 bytes (see
// pipeline override above), so the slot is dedicated to this
// pipeline instance and never reused by K/V tiles.
auto* vote_slot = reinterpret_cast<uint32_t*>(static_cast<char*>(smem_ptr) +
GetPerBlockVoteSlotOffset());
const int lane_id = threadIdx.x % warpSize;
const int warp_id = threadIdx.x / warpSize;
// Initialise the sentinel to 1 (skip-everything) before any
// wave votes. Only one thread does the init; the subsequent
// block_sync_lds() makes it visible to all waves.
if(warp_id == 0 && lane_id == 0)
{
*vote_slot = 1u;
}
block_sync_lds();
// Each wave contributes its warp_skip (already wave-uniform
// after the butterfly in compute_warp_skip). Lane 0 of each
// wave issues the atomicAnd; other lanes are idle. The atomic
// targets LDS (expected to lower to ds_and_b32), which is much
// cheaper than a global-memory atomic.
if(lane_id == 0)
{
atomicAnd(vote_slot, warp_skip ? 1u : 0u);
}
block_sync_lds();
// Broadcast the consensus back to every lane.
const uint32_t consensus = *vote_slot;
block_skip = (consensus != 0u);
}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(FmhaMask::IsMasking)
{
@@ -530,6 +630,10 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
// R25 redesign D: when kEnablePVSkip + warp_skip, we zero this
// warp's owned rows of p_compute so the unconditional gemm_1
// contributes zero to o_acc, and skip the rowsum.
// R30: per-block mode uses block_skip (uniform across waves) and
// additionally skips gemm_1 itself (see guard at the gemm_1 site
// below). The p_compute zeroing remains so rowsum_p -> 0 and
// `l += rowsum_p` is a no-op for skipped iters.
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
@@ -538,7 +642,15 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(kEnablePVSkip)
if constexpr(kPerBlockPVSkip)
{
if(block_skip)
{
p_compute(i_j_idx) = SMPLComputeDataType{0};
return;
}
}
else if constexpr(kEnablePVSkip)
{
if(warp_skip)
{
@@ -603,15 +715,39 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
number<-1>{},
bool_constant<false>{}); // load next v_buf
}
// block_sync_lds() stays UNCONDITIONAL — it is the
// workgroup barrier the V->LDS rotation chain requires
// (idiom catalog §3.1 / §4.1). Only the gemm_1 MFMA is
// gated on block_skip when in per-block mode.
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(kPerBlockPVSkip)
{
if(!block_skip)
{
gemm_1(
o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1,
kK1>{}));
}
}
else
{
gemm_1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1,
kK1>{}));
}
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
@@ -659,16 +795,37 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail — gemm_1 runs unconditionally under redesign D.
// tail — gemm_1 runs unconditionally under redesign D (per-wave).
// R30: per-block mode gates the MFMA on block_skip; block_sync_lds
// still runs unconditionally (workgroup barrier for LDS rotation).
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
if constexpr(kPerBlockPVSkip)
{
if(!block_skip)
{
gemm_1(
o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1,
kK1>{}));
}
}
else
{
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1,
kK1>{}));
}
}
} while(i_total_loops < num_total_loop);