mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
sparse_attn: split-launch dispatch + 3-mode PV-skip
- Per-head pv_threshold via head_remap LUT (CLI: -pv_threshold_per_head);
sentinel 1e30 routes to kEnablePVSkip=false bucket
- kEnablePVSkip bool → PVSkipMode enum {kNone, kPerWave, kPerBlock};
new kPerBlock matches upstream sm80 (LDS vote, V loads unconditional).
CLI: -pv_mode={none,warp,block}, default warp
- README: PV-skip modes section + MI300X 3-curve sparsity chart
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,9 @@
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
// PVSkipMode enum lives in the sparge pipeline header; pull it in so the
|
||||
// kernel template arg can name it (R30: promote bool kEnablePVSkip_ to 3-way enum).
|
||||
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@@ -21,7 +24,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_, bool kEnablePVSkip_ = true>
|
||||
template <typename FmhaPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave>
|
||||
struct FmhaFwdSpargeKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
@@ -30,7 +35,9 @@ struct FmhaFwdSpargeKernel
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
|
||||
static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_;
|
||||
// Legacy alias preserved: any non-kNone mode is "PV-skip enabled".
|
||||
static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone);
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
@@ -99,6 +106,15 @@ struct FmhaFwdSpargeKernel
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
float scale_s;
|
||||
float pv_threshold;
|
||||
// R26 split-launch: when non-null, indexed by remapped i_nhead (post head_remap),
|
||||
// overrides scalar pv_threshold. Buffer length = num_head_q.
|
||||
const float* pv_threshold_per_head;
|
||||
// R26 split-launch: when non-null, i_nhead = head_remap_ptr[blockIdx.y].
|
||||
// Buffer length = nhead_in_launch. Null = identity (blockIdx.y directly).
|
||||
const int* head_remap_ptr;
|
||||
// R26 split-launch: gridDim.y when head_remap_ptr is active (== bucket size).
|
||||
// Kept for future host-side asserts / debug; the kernel does not need it — it indexes head_remap_ptr by blockIdx.y.
|
||||
ck_tile::index_t nhead_in_launch;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -165,7 +181,12 @@ struct FmhaFwdSpargeKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
// R26 split-launch (default-null preserves
|
||||
// backward compat = scalar mode).
|
||||
const float* pv_threshold_per_head = nullptr,
|
||||
const int* head_remap_ptr = nullptr,
|
||||
ck_tile::index_t nhead_in_launch = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -185,6 +206,9 @@ struct FmhaFwdSpargeKernel
|
||||
scale_s,
|
||||
#endif
|
||||
pv_threshold,
|
||||
pv_threshold_per_head,
|
||||
head_remap_ptr,
|
||||
nhead_in_launch,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
@@ -224,7 +248,18 @@ struct FmhaFwdSpargeKernel
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
// R26 split-launch: if head_remap_ptr is set, translate the launch-local
|
||||
// head index to the original num_head_q-space index. Null pointer ->
|
||||
// identity (single-launch backward compat). The remap LUT load is uniform
|
||||
// across the wavefront (same blockIdx.y for all lanes), but the compiler
|
||||
// can't infer scalar-uniformity through a global ptr indirection, so we
|
||||
// broadcast via readfirstlane. Without this, dependent offset/buffer-
|
||||
// descriptor computations spill to VGPRs and buffer_load_dwordx4 inline
|
||||
// asm rejects the VGPR operand.
|
||||
const index_t i_nhead =
|
||||
(kargs.head_remap_ptr != nullptr)
|
||||
? __builtin_amdgcn_readfirstlane(kargs.head_remap_ptr[blockIdx.y])
|
||||
: static_cast<index_t>(blockIdx.y);
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
@@ -402,6 +437,23 @@ struct FmhaFwdSpargeKernel
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
// R26 split-launch: per-head pv_threshold override (null = scalar mode).
|
||||
// i_nhead is already scalar-broadcast in GetTileIndex; the load is uniform
|
||||
// and the resulting float lands in SGPRs naturally. We additionally route
|
||||
// via readfirstlane on the int representation as a defensive hint to keep
|
||||
// it scalar even when the compiler is conservative about float traffic.
|
||||
float pv_threshold_resolved;
|
||||
if(kargs.pv_threshold_per_head != nullptr)
|
||||
{
|
||||
const int raw = __builtin_amdgcn_readfirstlane(
|
||||
__builtin_bit_cast(int, kargs.pv_threshold_per_head[i_nhead]));
|
||||
pv_threshold_resolved = __builtin_bit_cast(float, raw);
|
||||
}
|
||||
else
|
||||
{
|
||||
pv_threshold_resolved = kargs.pv_threshold;
|
||||
}
|
||||
|
||||
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
@@ -409,7 +461,7 @@ struct FmhaFwdSpargeKernel
|
||||
valid_block_num_value,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.pv_threshold,
|
||||
pv_threshold_resolved,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
|
||||
@@ -11,18 +11,40 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// R30: PV-skip mode enum. R25 A1 shipped a per-wavefront vote; R30 adds a
|
||||
// per-block consensus vote (matches upstream SpargeAttn kPerBlock semantics;
|
||||
// see R29 researcher report per_block_vload_guard.md). kNone disables the
|
||||
// skip path entirely (AST removed). The legacy bool `kEnablePVSkip_=true`
|
||||
// maps to kPerWave; `false` maps to kNone — preserved via codegen.
|
||||
enum class PVSkipMode : int
|
||||
{
|
||||
kNone = 0, // PV-skip disabled; the skip path is compiled out
|
||||
kPerWave = 1, // each wavefront votes on skipping independently (default template argument)
|
||||
kPerBlock = 2, // block-wide consensus: skip only when every wave in the block votes to skip
|
||||
};
|
||||
|
||||
// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA;
|
||||
// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original
|
||||
// _vsa.hpp can remain frozen as an A/B baseline.
|
||||
//
|
||||
// R30: kPVSkipMode_ promoted from bool to 3-value enum {kNone, kPerWave, kPerBlock}.
|
||||
// kPerWave is the R25 A1 shipped path; kPerBlock adds a block-wide consensus AND vote
|
||||
// (1 LDS slot + 2 block_sync_lds) so all waves in a block agree before skipping the
|
||||
// PV mma. Per R29 audit, the V load / V->LDS store / cp_async pipeline stay
|
||||
// unconditional in BOTH per-wave and per-block modes (only the gemm_1 is gated).
|
||||
//
|
||||
// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs;
|
||||
// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim.
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
|
||||
bool kEnablePVSkip_ = true>
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy,
|
||||
PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave>
|
||||
struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
{
|
||||
static constexpr bool kEnablePVSkip = kEnablePVSkip_;
|
||||
static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_;
|
||||
// Legacy alias: true iff any PV-skip mode (per-wave or per-block) is active.
|
||||
// Kept so existing `if constexpr (kEnablePVSkip)` reads still compile.
|
||||
static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone);
|
||||
static constexpr bool kPerBlockPVSkip = (kPVSkipMode_ == PVSkipMode::kPerBlock);
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
@@ -140,7 +162,22 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
|
||||
static constexpr const char* name = "qr_async";
|
||||
|
||||
// R30: per-block PV-skip needs one int32 LDS slot to broadcast the AND-vote
|
||||
// result across waves. Reserved at the TAIL of the pipeline's LDS budget
|
||||
// (after the existing K + V allocations), 4 bytes, aligned. When mode is
|
||||
// kNone or kPerWave the slot is unused; its cost is negligible
|
||||
// (4 bytes vs the multi-kB K/V tiles) so we always reserve it to keep the
|
||||
// smem layout uniform across modes — simpler than per-mode policy plumbing.
|
||||
static constexpr ck_tile::index_t kPerBlockVoteSlotBytes = 4;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>() + kPerBlockVoteSlotBytes; // policy footprint + trailing vote slot
|
||||
}
|
||||
|
||||
// R30: byte offset of the per-block vote flag from `smem_ptr`. Lives just
|
||||
// past the policy's K+V smem footprint.
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPerBlockVoteSlotOffset()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>(); // slot begins exactly where the policy's footprint ends
|
||||
}
|
||||
@@ -513,6 +550,69 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
};
|
||||
const bool warp_skip = compute_warp_skip();
|
||||
|
||||
// ================================================================
|
||||
// R30: per-block PV-skip — block-wide AND vote over warp_skip.
|
||||
// Hand-rolled (no `block_and` primitive in CK-tile, no
|
||||
// `__syncthreads_and` analog — see R30 idiom catalog §7.5).
|
||||
//
|
||||
// Protocol:
|
||||
// 1. Lane 0 of each wave atomicAnd's its warp_skip int into a
|
||||
// shared LDS sentinel (initialised to 1 by lane 0 of wave 0
|
||||
// before the vote).
|
||||
// 2. block_sync_lds() — all stores visible, all waves rendezvous
|
||||
// (uses the same s_waitcnt+s_barrier discipline as the K/V
|
||||
// LDS chain; lgkmcnt accounting stays consistent — idiom
|
||||
// §3.1 / §4.2).
|
||||
// 3. All lanes read the sentinel back into a register. The
|
||||
// result is uniform across the whole block (every lane reads the
|
||||
// same slot) — used to gate gemm_1 at both gemm_1 sites below.
|
||||
//
|
||||
// Cost: 1 LDS init + 2 block_sync_lds + 1 atomicAnd + 1 LDS load.
|
||||
// The vote slot lives at `smem_ptr + GetPerBlockVoteSlotOffset()`,
|
||||
// a 4-byte slot immediately past the policy K+V budget (see GetSmemSize override).
|
||||
// No interaction with LdsSeq rotation slots.
|
||||
//
|
||||
// V load / V->LDS store / cp_async pipeline stay UNCONDITIONAL in
|
||||
// both per-wave and per-block modes — matches upstream SpargeAttn
|
||||
// (R29 audit) and CK-tile LDS-rotation discipline.
|
||||
// ================================================================
|
||||
bool block_skip = false;
|
||||
if constexpr(kPerBlockPVSkip)
|
||||
{
|
||||
// Carve a 4-byte uint32 slot at the LDS tail. The cast is safe:
|
||||
// GetSmemSize() bumped the smem_ptr allocation by 4 bytes (see
|
||||
// pipeline override above), so the slot is dedicated to this
|
||||
// pipeline instance and never reused by K/V tiles.
|
||||
auto* vote_slot = reinterpret_cast<uint32_t*>(static_cast<char*>(smem_ptr) +
|
||||
GetPerBlockVoteSlotOffset());
|
||||
|
||||
const int lane_id = threadIdx.x % warpSize;
|
||||
const int warp_id = threadIdx.x / warpSize;
|
||||
|
||||
// Initialise the sentinel to 1 (skip-everything) before any
|
||||
// wave votes. Only one thread does the init; the subsequent
|
||||
// block_sync_lds() makes it visible to all waves.
|
||||
if(warp_id == 0 && lane_id == 0)
|
||||
{
|
||||
*vote_slot = 1u;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Each wave contributes its warp_skip (already wave-uniform
|
||||
// after the butterfly in compute_warp_skip). Lane 0 of each
|
||||
// wave issues the atomicAnd; other lanes are idle. The atomic
|
||||
// is on LDS (expected to lower to ds_and_b32), much cheaper than global.
|
||||
if(lane_id == 0)
|
||||
{
|
||||
atomicAnd(vote_slot, warp_skip ? 1u : 0u);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Broadcast the consensus back to every lane.
|
||||
const uint32_t consensus = *vote_slot;
|
||||
block_skip = (consensus != 0u);
|
||||
}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -530,6 +630,10 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
// R25 redesign D: when kEnablePVSkip + warp_skip, we zero this
|
||||
// warp's owned rows of p_compute so the unconditional gemm_1
|
||||
// contributes zero to o_acc, and skip the rowsum.
|
||||
// R30: per-block mode uses block_skip (uniform across waves) and
|
||||
// additionally skips gemm_1 itself (see guard at the gemm_1 site
|
||||
// below). The p_compute zeroing remains so rowsum_p -> 0 and
|
||||
// `l += rowsum_p` is a no-op for skipped iters.
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -538,7 +642,15 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if constexpr(kEnablePVSkip)
|
||||
if constexpr(kPerBlockPVSkip)
|
||||
{
|
||||
if(block_skip)
|
||||
{
|
||||
p_compute(i_j_idx) = SMPLComputeDataType{0};
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if constexpr(kEnablePVSkip)
|
||||
{
|
||||
if(warp_skip)
|
||||
{
|
||||
@@ -603,15 +715,39 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
number<-1>{},
|
||||
bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
// block_sync_lds() stays UNCONDITIONAL — it is the
|
||||
// workgroup barrier the V->LDS rotation chain requires
|
||||
// (idiom catalog §3.1 / §4.1). Only the gemm_1 MFMA is
|
||||
// gated on block_skip when in per-block mode.
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
|
||||
if constexpr(kPerBlockPVSkip)
|
||||
{
|
||||
if(!block_skip)
|
||||
{
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1,
|
||||
kK1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1,
|
||||
kK1>{}));
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -659,16 +795,37 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail — gemm_1 runs unconditionally under redesign D.
|
||||
// tail — gemm_1 runs unconditionally under redesign D (per-wave).
|
||||
// R30: per-block mode gates the MFMA on block_skip; block_sync_lds
|
||||
// still runs unconditionally (workgroup barrier for LDS rotation).
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
if constexpr(kPerBlockPVSkip)
|
||||
{
|
||||
if(!block_skip)
|
||||
{
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1,
|
||||
kK1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1,
|
||||
kK1>{}));
|
||||
}
|
||||
}
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user