mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: decouple K/V DRAM loads per warp-group + early V read (FA4 fp8)
Split the cooperative K/V cache loads across the two FA4 warp groups so each group owns exactly one tile's DRAM load and address arithmetic: WG0 loads V, WG1 loads K, and both read from the shared LDS buffers.

- kFA4WG0LoadsV / kFA4WG1LoadsK policy flags + GetVLoadNumWarps / GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via 4-warp descriptors; the partner skips the load and reads from LDS.
- High-warp-group support for the raw async path: the raw store bakes the absolute warp id into the LDS M0, so WG1 (waves 4-7) needs a base shift (GetKStoreWarpShift / WarpIdShift in MakeKLdsStoreBlockDescriptor) to map back to the 4-warp layout, plus WG-relative (warp % NumWarps) page offsets so the gather token positions are correct.
- Stage B: move each tile's V LDS read into the PRECEDING softmax phase so the read latency hides under softmax VALU. Safe because V is now single-group-owned; uses drain-before-barrier (vmcnt<0> then s_barrier) so all 4 cooperating writer waves' slices are published before the read.
- Gate per-tile offset refresh per warp-group (WG0 refreshes V, WG1 refreshes K), so each wave fetches a block-table page index for one tile instead of both; loop counters stay uniform.

Validated 0% mismatch vs GPU reference, causal + non-causal, sq 256..8192. Net latency change vs the cooperative baseline: causal roughly -3% to -4.6%, non-causal roughly -2% to -4.7% across sq 2048..16384 (d128 fp8).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -488,6 +488,19 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
const index_t warp_group_id = get_warp_id() / 4;
|
||||
|
||||
// FA4 "WG0 loads V": warp group 0's 4 waves load the FULL V tile into
|
||||
// the shared V LDS buffer (V descriptors use VLoadNumWarps == 4 waves);
|
||||
// warp group 1 skips the V DRAM load and relies on the inter-phase
|
||||
// barrier for residency. v_load_active gates the async V load issue.
|
||||
constexpr index_t VLoadNumWarps = Policy::template GetVLoadNumWarps<Problem>();
|
||||
constexpr index_t KLoadNumWarps = Policy::template GetKLoadNumWarps<Problem>();
|
||||
constexpr index_t NumWarpGroups_ = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
|
||||
const bool v_load_active =
|
||||
(!Policy::kFA4WG0LoadsV) || (NumWarpGroups_ != 2) || (warp_group_id == 0);
|
||||
// Symmetric: warp group 1 alone loads K (WG0 reads from shared LDS).
|
||||
const bool k_load_active =
|
||||
(!Policy::kFA4WG1LoadsK) || (NumWarpGroups_ != 2) || (warp_group_id == 1);
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
|
||||
@@ -531,34 +544,41 @@ struct UnifiedAttentionPipeline
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift<Problem>();
|
||||
auto k_lds_window_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_lds_tile_window<KDataType>(
|
||||
smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
|
||||
smem_ptr,
|
||||
Policy::template MakeKLdsStoreBlockDescriptor<Problem,
|
||||
KLoadNumWarps,
|
||||
KStoreWarpShift>(i_buf));
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
auto v_lds_window_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_lds_tile_window<KDataType>(
|
||||
smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
|
||||
smem_ptr,
|
||||
Policy::template MakeVLdsStoreBlockDescriptor<Problem, VLoadNumWarps>(i_buf));
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
statically_indexed_array<decltype(make_tile_window(
|
||||
make_lds_tile_window<KDataType>(
|
||||
nullptr,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
|
||||
Policy::template MakeKRegTileDistribution<Problem>())),
|
||||
2>
|
||||
statically_indexed_array<
|
||||
decltype(make_tile_window(
|
||||
make_lds_tile_window<KDataType>(
|
||||
nullptr,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
|
||||
Policy::template MakeKRegTileDistribution<Problem>())),
|
||||
2>
|
||||
k_lds_window_load;
|
||||
|
||||
statically_indexed_array<decltype(make_tile_window(
|
||||
make_lds_tile_window<VDataType>(
|
||||
nullptr,
|
||||
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
|
||||
Policy::template MakeVRegTileDistribution<Problem>())),
|
||||
2>
|
||||
statically_indexed_array<
|
||||
decltype(make_tile_window(
|
||||
make_lds_tile_window<VDataType>(
|
||||
nullptr,
|
||||
Policy::template MakeVLdsLoadBlockDescriptor<Problem, VLoadNumWarps>()),
|
||||
Policy::template MakeVRegTileDistribution<Problem>())),
|
||||
2>
|
||||
v_lds_window_load;
|
||||
|
||||
decltype(make_static_distributed_tensor<QDataType>(
|
||||
@@ -619,7 +639,7 @@ struct UnifiedAttentionPipeline
|
||||
k_lds_window_load(idx) = make_tile_window(
|
||||
make_lds_tile_window<KDataType>(
|
||||
static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem, KLoadNumWarps>()),
|
||||
Policy::template MakeKRegTileDistribution<Problem>());
|
||||
});
|
||||
|
||||
@@ -628,7 +648,8 @@ struct UnifiedAttentionPipeline
|
||||
make_tile_window(make_lds_tile_window<VDataType>(
|
||||
static_cast<char*>(smem_ptr) +
|
||||
(idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
|
||||
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
|
||||
Policy::template MakeVLdsLoadBlockDescriptor<Problem,
|
||||
VLoadNumWarps>()),
|
||||
Policy::template MakeVRegTileDistribution<Problem>());
|
||||
});
|
||||
|
||||
@@ -714,8 +735,8 @@ struct UnifiedAttentionPipeline
|
||||
// a large negative relative offset that the HW OOB check clamps to 0).
|
||||
// A robust fix would either plumb long_index_t through the gather load
|
||||
// path or compute a per-batch min-page shift in a pre-pass.
|
||||
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
|
||||
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem, KLoadNumWarps>();
|
||||
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem, VLoadNumWarps>();
|
||||
using KDstrType = decltype(k_dist);
|
||||
using VDstrType = decltype(v_dist);
|
||||
constexpr index_t KNRepeat =
|
||||
@@ -729,8 +750,19 @@ struct UnifiedAttentionPipeline
|
||||
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
|
||||
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
|
||||
|
||||
const auto k_thread_coord = k_dist.calculate_index();
|
||||
const auto v_thread_coord = v_dist.calculate_index();
|
||||
// WG-relative warp index for the gather page-offset computation. When a
|
||||
// single warp group loads a tile (V by WG0 / K by WG1), only that
|
||||
// group's waves issue the load, and their absolute warp ids must be
|
||||
// folded back into [0, NumWarps) so k_thread_n_pos / v_thread_n_pos (the
|
||||
// per-wave token position baked into page_idx) match the group-relative
|
||||
// distribution. For the cooperative case NumWarps == full block, so the
|
||||
// modulo is the identity. The scatter-gather's own get_partition_index
|
||||
// use is harmless here: the gather (token) dim is zeroed and replaced by
|
||||
// page_idx, and the remaining (head-dim) coordinate is lane-based.
|
||||
const auto k_part = ck_tile::array<index_t, 2>{get_warp_id() % KLoadNumWarps, get_lane_id()};
|
||||
const auto v_part = ck_tile::array<index_t, 2>{get_warp_id() % VLoadNumWarps, get_lane_id()};
|
||||
const auto k_thread_coord = k_dist.calculate_index(k_part);
|
||||
const auto v_thread_coord = v_dist.calculate_index(v_part);
|
||||
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
|
||||
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
|
||||
|
||||
@@ -1456,12 +1488,24 @@ struct UnifiedAttentionPipeline
|
||||
// num_blocks_start.
|
||||
const index_t num_iters_per_split = num_total_loop - num_blocks_start;
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
else
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
// FA4 "WG1 loads K": only warp group 1's waves issue the K async load
|
||||
// (its KLoadNumWarps==4 layout fills the full shared K tile). WG0
|
||||
// skips it and reads K from shared LDS (barrier-synchronized).
|
||||
if(k_load_active)
|
||||
{
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
else
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
}
|
||||
k_block_idx++;
|
||||
if(k_block_idx < num_iters_per_split)
|
||||
// Only the K-loading warp group needs K offsets refreshed: with
|
||||
// kFA4WG1LoadsK, WG0 never issues a K load, so computing its page
|
||||
// offsets (incl. the block-table page-index ds_read) is pure waste.
|
||||
// k_block_idx itself stays uniform across all waves so loop control
|
||||
// and buffer parity never diverge. Gating here also means each wave
|
||||
// fetches a page-table index for exactly ONE tile (K *or* V), not both.
|
||||
if(k_load_active && k_block_idx < num_iters_per_split)
|
||||
{
|
||||
refresh_k_offsets(k_block_idx);
|
||||
if constexpr(kRebaseKSrd)
|
||||
@@ -1474,12 +1518,22 @@ struct UnifiedAttentionPipeline
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
else
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
// FA4 "WG0 loads V": only warp group 0's waves issue the V async
|
||||
// load (its VLoadNumWarps==4 layout fills the full shared V tile).
|
||||
// WG1 skips the load; the v_block_idx bookkeeping (not the offsets) stays
|
||||
// uniform across all waves so the loop's scalar state never diverges.
|
||||
if(v_load_active)
|
||||
{
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
else
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
}
|
||||
v_block_idx++;
|
||||
if(v_block_idx < num_iters_per_split)
|
||||
// Symmetric to K: only the V-loading warp group (WG0) refreshes V
|
||||
// offsets; WG1 skips it (it never issues a V load). v_block_idx stays
|
||||
// uniform for loop/buffer bookkeeping.
|
||||
if(v_load_active && v_block_idx < num_iters_per_split)
|
||||
{
|
||||
refresh_v_offsets(v_block_idx);
|
||||
if constexpr(kRebaseVSrd)
|
||||
@@ -2165,41 +2219,23 @@ struct UnifiedAttentionPipeline
|
||||
auto gemm1 = number<1>{};
|
||||
|
||||
// MATRIX phase: deferred PV(k-1) then QK(k). Pure matrix pipe.
|
||||
// Consumes V(1-pi) / K(pi) resident in LDS; the union kv_tile holds
|
||||
// v_tile for the PV then is overwritten with k_tile for the QK
|
||||
// (V_lds → gemm1 → K_lds → gemm0 ordering).
|
||||
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
|
||||
// EARLY — before the matrix phase's compute — so its ~LDS-latency
|
||||
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
|
||||
// exit (WG1) instead of being exposed right before the PV MFMA. The
|
||||
// V buffer pi was populated by a prior prefetch and already waited
|
||||
// on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi]
|
||||
// so there is no aliasing with this read.
|
||||
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
|
||||
// EARLY — before the matrix phase's compute — so its ~LDS-latency
|
||||
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
|
||||
// exit (WG1) instead of being exposed right before the PV MFMA. The
|
||||
// V buffer pi was populated by a prior prefetch and already waited
|
||||
// on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi]
|
||||
// so there is no aliasing with this read.
|
||||
// Consumes V(pi) / K(1-pi) resident in LDS; kv_tile holds v_tile for
|
||||
// the PV and (separately) k_tile for the QK.
|
||||
//
|
||||
// NOTE: do NOT also hoist K_lds_load here. The QK gemm reads K-buf
|
||||
// [1-pi], and in the WG1 softmax-first prologue that buffer is not
|
||||
// yet guaranteed resident this early (its async load completes a
|
||||
// phase later) — hoisting K corrupts long-context runs. K stays
|
||||
// issued between the PV and QK MFMAs (its latency hides under PV).
|
||||
// V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi)
|
||||
// EARLY — before the matrix phase's compute — so its ~LDS-latency
|
||||
// overlaps the address-calc VALU of prefetch() (WG0) / the barrier
|
||||
// exit (WG1) instead of being exposed right before the PV MFMA.
|
||||
// V-read into SOFTMAX (Stage B): the PV gemm's V tile (v_rd == pi) is
|
||||
// now read in the *preceding* SOFTMAX phase rather than at the top of
|
||||
// this MATRIX phase, so its ~LDS latency overlaps the full softmax
|
||||
// VALU (exp / rowsum / P-cvt) instead of only the prefetch address
|
||||
// calc. This is safe now that V is loaded by a single warp group
|
||||
// (kFA4WG0LoadsV): WG0 reads V it loaded itself, so its own vmcnt<0>
|
||||
// proves residency (no partner dependency); WG1 reads an already-
|
||||
// barrier-published V buffer. The pre-read lands in v_tile and this
|
||||
// MATRIX phase consumes it directly (see fa4_softmax / the WG0 prime).
|
||||
//
|
||||
// NOTE: do NOT also hoist K_lds_load. K/V are loaded COOPERATIVELY
|
||||
// by both warp groups; a wave's vmcnt only drains its OWN async
|
||||
// loads, not the partner group's half, so a cooperatively-filled
|
||||
// buffer is reliably resident only deeper into the phase. The PV
|
||||
// gemm provides exactly that slack for the K read — moving K ahead
|
||||
// of PV races the partner's load completion and corrupts long
|
||||
// contexts. K stays issued between the PV and QK MFMAs.
|
||||
// NOTE: do NOT hoist K_lds_load the same way. The QK gemm reads K-buf
|
||||
// [1-pi] which the partner group (WG1) loads; WG0 has no own-vmcnt
|
||||
// proof of its residency this early, only the barrier deeper in the
|
||||
// phase. K stays issued between the PV and QK MFMAs (latency under PV).
|
||||
auto fa4_vload = [&](auto pi) { V_lds_load(pi); };
|
||||
|
||||
auto fa4_matrix = [&](auto pi) {
|
||||
@@ -2207,14 +2243,15 @@ struct UnifiedAttentionPipeline
|
||||
auto qk_sp = number<1>{} - pi; // QK target slot
|
||||
auto k_rd = number<1>{} - pi;
|
||||
|
||||
s_waitcnt_lgkmcnt<0>(); // wait the hoisted fa4_vload(pi)
|
||||
s_waitcnt_lgkmcnt<0>(); // wait the V pre-read issued in prev SOFTMAX
|
||||
gemm(pv_sp, gemm1); // o_acc += P(pi) @ V(k-1)
|
||||
// K read into its OWN registers (kv_tile no longer a union), so
|
||||
// this ds_read executes on the LSU *during* the PV MFMA above
|
||||
// instead of waiting for it to retire. The sched_barrier pins it
|
||||
// here (program order) so it is NOT hoisted above the PV gemm --
|
||||
// that would race the partner WG's cooperative K load and
|
||||
// corrupt long contexts (the residency hazard documented above).
|
||||
// K read into its OWN registers (k_tile no longer aliases v_tile),
|
||||
// so this ds_read executes on the LSU *during* the PV MFMA above
|
||||
// rather than waiting for it to retire; the sched_barriers pin it
|
||||
// here. K is now single-warp-group loaded (kFA4WG1LoadsK) so it is
|
||||
// resident at the slot-A barrier, but issuing the read AFTER the
|
||||
// PV gemm call (overlapping the in-flight MFMA) schedules strictly
|
||||
// better than hoisting it ahead of PV — measured ~3-4% faster.
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
K_lds_load(k_rd); // overlaps the PV MFMA (latency hidden)
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -2256,20 +2293,32 @@ struct UnifiedAttentionPipeline
|
||||
if constexpr(cl_p == 0)
|
||||
{
|
||||
// ---- slot A: MATRIX(pi) ‖ (WG1: SOFTMAX) ----
|
||||
// V tile (buf pi) was pre-read into v_tile in the previous
|
||||
// SOFTMAX phase (or the WG0 prime for the first tile).
|
||||
ASM_MARKER("fa4 MATRIX Wave0-3");
|
||||
s_waitcnt_vmcnt<0>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
fa4_vload(pi); // hoisted V read; latency hidden under prefetch
|
||||
prefetch();
|
||||
fa4_matrix(pi);
|
||||
|
||||
// ---- slot B: SOFTMAX(pi) ‖ (WG1: MATRIX) ----
|
||||
// Pre-read the next MATRIX's V tile (buf 1-pi == the buffer
|
||||
// this iteration's prefetch just filled), overlapping the
|
||||
// softmax VALU below; v_tile survives into the next slot-A
|
||||
// MATRIX (PV consumes it via lgkmcnt<0>). The V buffer is
|
||||
// filled cooperatively by WG0's 4 waves, so a wave reads
|
||||
// slices written by its peers: drain the load (vmcnt<0>) and
|
||||
// then cross the phase barrier so all 4 waves' writes are
|
||||
// published BEFORE the read (drain-before-barrier; reading
|
||||
// after only an own-vmcnt races the peers' slices).
|
||||
ASM_MARKER("fa4 SOFTMAX Wave0-3");
|
||||
s_waitcnt_vmcnt<0>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
fa4_vload(number<1>{} - pi);
|
||||
fa4_softmax(pi);
|
||||
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
@@ -2281,20 +2330,28 @@ struct UnifiedAttentionPipeline
|
||||
// WG1 is one phase ahead (primed by the FA4 prologue): it
|
||||
// softmaxes the tile it QK'd in its previous MATRIX phase
|
||||
// while WG0 runs the MATRIX of the same tile.
|
||||
//
|
||||
// Pre-read this iteration's slot-B MATRIX V tile (buf pi).
|
||||
// That buffer was filled by WG0 and already drained+published
|
||||
// by WG0's drain-before-barrier in its prior SOFTMAX slot, so
|
||||
// the slot-A barrier just crossed guarantees all 4 writer
|
||||
// waves' slices are visible. The read overlaps the softmax
|
||||
// VALU below; v_tile survives into the slot-B MATRIX.
|
||||
ASM_MARKER("fa4 SOFTMAX Wave4-7");
|
||||
s_waitcnt_vmcnt<0>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
fa4_vload(pi);
|
||||
prefetch();
|
||||
fa4_softmax(number<1>{} - pi);
|
||||
|
||||
// ---- slot B: MATRIX(pi) ‖ (WG0: SOFTMAX) ----
|
||||
// v_tile holds buf pi from the slot-A pre-read above.
|
||||
ASM_MARKER("fa4 MATRIX Wave4-7");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
fa4_vload(pi); // hoisted V read (overlaps barrier exit)
|
||||
fa4_matrix(pi);
|
||||
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
@@ -2508,6 +2565,11 @@ struct UnifiedAttentionPipeline
|
||||
fmha_alu0(number<0>{});
|
||||
fmha_alu_D_upd();
|
||||
fmha_alu1(number<0>{}); // sp(0).p = P(0)
|
||||
// Prime v_tile for the first MATRIX(0): V buf 0 was loaded by
|
||||
// WG0 in the pre-stage, so its own vmcnt<0> proves residency.
|
||||
// (Stage B reads each subsequent tile's V in the prior SOFTMAX.)
|
||||
s_waitcnt_vmcnt<0>();
|
||||
V_lds_load(number<0>{});
|
||||
while(core_loop_fa4(number<0>{}))
|
||||
;
|
||||
}
|
||||
|
||||
@@ -141,7 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
return 16 / sizeof(VDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
// NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG1 loads K"
|
||||
// path passes NumThreadPerWarpGroup/WarpSize so warp group 1's waves alone
|
||||
// tile the full K buffer (the partner group reads it from shared LDS).
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -149,7 +153,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
@@ -158,7 +162,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr index_t N0 = NumIssues;
|
||||
constexpr index_t N1 = LaneGroups;
|
||||
@@ -175,7 +180,13 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
// NumWarpsOverride lets the FA4 single-warp-group ("WG0 loads V") path request a
|
||||
// distribution where only NumWarpsOverride waves cooperate on the load (so one
|
||||
// warp group loads the FULL V tile by itself into the shared V LDS buffer, and
|
||||
// the tile's residency no longer depends on the partner group's loads).
|
||||
// Default = the shape's NumWarps (the original block-cooperative load).
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -183,7 +194,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64
|
||||
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
@@ -193,7 +204,10 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// NumWarps-relative form (NumWarps may be < the full block's warp count when
|
||||
// the FA4 single-warp-group V-load path passes a NumWarpsOverride).
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr index_t N0 = NumIssues; // 8
|
||||
constexpr index_t N1 = LaneGroups; // 2
|
||||
@@ -378,7 +392,19 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
|
||||
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
// WarpIdShift handles a sub-block load issued by a NON-zero warp group via
|
||||
// the raw async path. The raw store derives its LDS offset as
|
||||
// M0 = base + size_per_wave * get_warp_id() (ABSOLUTE warp id 0..7)
|
||||
// so a NumWarps-wide (e.g. 4-wave) layout only tiles correctly for warp ids
|
||||
// 0..NumWarps-1. When warp group g (>0) alone fills the tile, its waves have
|
||||
// absolute ids [g*NumWarps, (g+1)*NumWarps); shifting the descriptor base by
|
||||
// -WarpIdShift*size_per_wave (WarpIdShift = g*NumWarps) maps them back to
|
||||
// effective ids 0..NumWarps-1, i.e. the exact physical layout a warp-group-0
|
||||
// load would produce -- so the (unshifted) read descriptor reads it directly.
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
|
||||
ck_tile::index_t WarpIdShift = 0,
|
||||
ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
{
|
||||
@@ -388,7 +414,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
@@ -405,7 +431,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
@@ -418,7 +445,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>() -
|
||||
WarpIdShift*(WarpSize * KVector + kPad)>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -436,7 +464,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
return k_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -445,7 +474,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
@@ -458,7 +487,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
@@ -553,7 +583,12 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
return max(max(SingleKSize, SingleVSize), VLoadDescSize);
|
||||
}
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
// NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG0 loads V"
|
||||
// path passes NumThreadPerWarpGroup/WarpSize (== 4) so warp group 0's waves
|
||||
// alone tile the full V buffer. Default = the shape's NumWarps (cooperative).
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps,
|
||||
ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
{
|
||||
@@ -563,7 +598,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
@@ -580,7 +615,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
@@ -611,7 +647,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
return v_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
ck_tile::index_t NumWarpsOverride = Problem::UnifiedAttentionShape::NumWarps>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -620,7 +657,7 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t NumWarps = NumWarpsOverride;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
@@ -633,7 +670,8 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector));
|
||||
static_cast<void>(kBlockSize);
|
||||
|
||||
constexpr auto v_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
@@ -688,6 +726,63 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
return kv_element_space_size_in_bytes;
|
||||
}
|
||||
|
||||
// FA4 "WG0 loads V" prototype: when the block runs as two warp groups, have
|
||||
// ONLY warp group 0 (waves 0-3) load the full V tile into the shared V LDS
|
||||
// buffer (V's DRAM dist + LDS descriptors use NumThreadPerWarpGroup/WarpSize
|
||||
// == 4 waves so WG0 alone fills the tile). WG1 skips the V DRAM load
|
||||
// entirely. No 2x DRAM, no extra LDS (V stays a shared 2-buffer). This
|
||||
// decouples V's residency from the partner group's cooperative-load shard
|
||||
// (WG0's own vmcnt proves the load) so the V LDS read can later move into
|
||||
// the SOFTMAX phase. K's load ownership is set independently by kFA4WG1LoadsK.
|
||||
// Toggle to false to restore the block-cooperative (8-wave) V load.
|
||||
static constexpr bool kFA4WG0LoadsV = true;
|
||||
|
||||
// Symmetric K decoupling: warp group 1 (waves 4-7) alone loads the full K
|
||||
// tile into the shared K LDS buffer; warp group 0 reads it from shared LDS.
|
||||
// Together with kFA4WG0LoadsV this balances DRAM-load work (WG0->V, WG1->K)
|
||||
// and lets each group issue only one tile's load/address instructions.
|
||||
static constexpr bool kFA4WG1LoadsK = true;
|
||||
|
||||
// Number of waves that cooperate on a V DRAM->LDS load. For the 2-warp-group
|
||||
// FA4 path with kFA4WG0LoadsV, this is one warp group's waves (so WG0 alone
|
||||
// fills the tile); otherwise it's the full block (original cooperative load).
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetVLoadNumWarps()
|
||||
{
|
||||
constexpr ck_tile::index_t NumWarpGroups =
|
||||
Problem::kBlockSize / NumThreadPerWarpGroup;
|
||||
if constexpr(kFA4WG0LoadsV && NumWarpGroups == 2)
|
||||
return NumThreadPerWarpGroup / ck_tile::get_warp_size();
|
||||
else
|
||||
return Problem::UnifiedAttentionShape::NumWarps;
|
||||
}
|
||||
|
||||
// K analogue of GetVLoadNumWarps (warp group 1 alone fills the K tile).
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetKLoadNumWarps()
|
||||
{
|
||||
constexpr ck_tile::index_t NumWarpGroups =
|
||||
Problem::kBlockSize / NumThreadPerWarpGroup;
|
||||
if constexpr(kFA4WG1LoadsK && NumWarpGroups == 2)
|
||||
return NumThreadPerWarpGroup / ck_tile::get_warp_size();
|
||||
else
|
||||
return Problem::UnifiedAttentionShape::NumWarps;
|
||||
}
|
||||
|
||||
// Raw-async warp-id shift for the K store (see MakeKLdsStoreBlockDescriptor):
|
||||
// K is loaded by warp group 1, whose absolute warp ids start at one warp
|
||||
// group's worth of waves, so the store base must shift by that many waves.
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetKStoreWarpShift()
|
||||
{
|
||||
constexpr ck_tile::index_t NumWarpGroups =
|
||||
Problem::kBlockSize / NumThreadPerWarpGroup;
|
||||
if constexpr(kFA4WG1LoadsK && NumWarpGroups == 2)
|
||||
return NumThreadPerWarpGroup / ck_tile::get_warp_size(); // WG1's first abs warp id
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user