CK-UA: derive kBlockQ at runtime, decouple from variant template

kBlockQ (= kBlockM / num_queries_per_kv) was constexpr in
`UnifiedAttentionShape` / the kernel traits, forcing one kernel
instance per (kBlockM, num_qpkv) pair even though the matmul tile is
fully determined by kBlockM and kHeadDim. An audit confirmed that
kBlockQ only feeds:

  * arithmetic in `unified_attention_kernel.hpp` (loop bounds, Q-tile
    indexing, query_len padding),
  * `pad_tensor_view` size tuples for Q/O/LSE DRAM views,
  * one `mask.IsEdgeTile(... number<kBlockQ>{} ...)` call inside the
    pipeline's per-K-tile mask check.
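
A worked example of the derivation (kBlockM = 128 is assumed purely for
illustration; this commit does not change any tile sizes):

  // Assuming kBlockM = 128 (illustrative numbers only):
  //   MHA   : num_queries_per_kv = 1  ->  kBlockQ = 128 / 1 = 128
  //   GQA-8 : num_queries_per_kv = 8  ->  kBlockQ = 128 / 8 = 16
  // Before: each value baked into the trait -> one instance per pair.
  // After : derived in-kernel from kargs    -> instances keyed only on
  //         (kBlockM, kHeadDim).
  const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv;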

None of these structurally need a compile-time value:

* `pad_tensor_view` already accepts mixed runtime/compile-time tuple
  elements (e.g. it's passed plain `1` next to `kHeadDimPadded`).
* `IsEdgeTile` only does runtime arithmetic on the tile size; adding a
  runtime overload that accepts `index_t` is trivial (the compile-time
  one now forwards to it).
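
A condensed sketch of both points, using the identifiers from the diff
below (the surrounding `q_dram_base` view, `mask` object and tile
coordinates are assumed to be in scope):

  // Mixed tuple: kBlockQ_dyn and 1 are runtime index_t values,
  // kHeadDimPadded stays a compile-time constant, exactly as before.
  const auto q_dram_pad = pad_tensor_view(
      q_dram_base,
      make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
      sequence<true, false, kPadHeadDimQ>{});

  // Runtime IsEdgeTile overload: plain index_t tile sizes; the
  // number<> overload now just forwards here.
  bool edge = mask.IsEdgeTile(i_tile_top, i_tile_left,
                              kBlockQ_dyn,
                              static_cast<index_t>(kPageBlockSize));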

Wiring:
  * `block_masking.hpp` -- add an `IsEdgeTile(..., index_t tile_h,
    index_t tile_w)` overload; the existing `number<>` overload just
    forwards to it.
  * `unified_attention_pipeline.hpp` -- new optional
    `num_queries_per_kv` arg on the pipeline's `operator()` (the
    default of 0 keeps existing call sites unchanged). Computes
    `kBlockQ_dyn = (num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`
    once at the top and uses it in the `IsEdgeTile` call.
  * `unified_attention_kernel.hpp` -- compute
    `const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv`
    once and replace every use of `kBlockQ` with `kBlockQ_dyn`. Pass
    `kargs.num_queries_per_kv` through to the pipeline. The debug-only
    `assert(kBlockQ_dyn == kBlockQ)` keeps the static and dynamic
    values in lock-step until we actually collapse variants (condensed
    end-to-end sketch after this list).
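
Condensed, the plumbing across the two files looks like this (a sketch
assembled from the hunks below, not new API surface; argument lists
are abbreviated):

  // unified_attention_kernel.hpp
  const index_t num_queries_per_kv = kargs.num_queries_per_kv;
  const index_t kBlockQ_dyn        = kBlockM / num_queries_per_kv;
  assert(kBlockQ_dyn == kBlockQ);  // lock-step with the static trait
  // ... kBlockQ_dyn replaces kBlockQ in loop bounds, Q-tile indexing,
  // and the pad_tensor_view size tuples ...
  auto pipeline_result = pipeline(/* existing args */,
                                  k_row_stride, v_row_stride,
                                  num_queries_per_kv);

  // unified_attention_pipeline.hpp, inside operator()
  const index_t kBlockQ_dyn =
      (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
  bool need_perpixel_check =
      mask.IsEdgeTile(q_origin.at(number<0>{}),
                      i_total_loops * kPageBlockSize,
                      kBlockQ_dyn, static_cast<index_t>(kPageBlockSize));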

Perf A/B (b=4..256, sk=120000, MI300):

  d=128 MHA (num_qpkv = 1, runtime div is trivial):
    BW within +/-0.2% across all batch sizes (noise).

  d=64 GQA-8 (num_qpkv = 8, the runtime division actually happens):
    speedups of 1.28x..2.14x vs Triton -- identical to the pre-change
    baseline.

Correctness suite stays at 241/245 (same 4 pre-existing int32-overflow
failures in the d=128 prefill rebased-pointer path).

This is a no-op on perf and unlocks a follow-up where we collapse the
two num_qpkv variants per (head_dim, kBlockM) pair -- e.g. a future
d=128 GQA-8 variant can reuse the existing decode_d128_mha_* instances
just by passing a different runtime num_queries_per_kv.
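
A rough sketch of what that follow-up dispatch could look like
(hypothetical host-side code; `decode_d128_mha_run` and the surrounding
dispatch are placeholders, not part of this commit):

  // Hypothetical: a single compiled (head_dim = 128, kBlockM) instance
  // serves both layouts once num_queries_per_kv is purely runtime.
  kargs.num_queries_per_kv = num_query_heads / num_kv_heads; // 1 or 8
  decode_d128_mha_run(kargs, stream); // reuses the existing MHA instance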

Co-authored-by: Cursor <cursoragent@cursor.com>
Author: juuso-oskari
Date:   2026-05-12 12:01:59 +00:00
Commit: 614afea7eb (parent f5beedb2e9)
3 changed files with 61 additions and 26 deletions

block_masking.hpp

@@ -214,6 +214,17 @@ struct GenericAttentionMask
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
{
return IsEdgeTile(i_tile_top, i_tile_left, index_t{TileHeight}, index_t{TileWidth});
}
// Runtime overload. The compile-time variant above wraps this one so call
// sites that pass `number<>{}` keep working unchanged; callers that need a
// runtime tile size (e.g. when kBlockQ is derived from a runtime
// num_queries_per_kv) can call this directly. IsEdgeTile's body only does
// runtime arithmetic, so this is a no-op for current call sites.
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, index_t tile_h, index_t tile_w) const
{
// Transform the y index according to repeat_idx
index_t y_eff = i_tile_top / repeat_idx;
@@ -221,15 +232,15 @@ struct GenericAttentionMask
if constexpr(!IsMasking)
{
// TODO: no need to check begin
return (i_tile_left + TileWidth) > x_total;
return (i_tile_left + tile_w) > x_total;
}
else
{
if constexpr(IsLocal)
{
// check top-right corner > x or left-bottom corner < x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = y_eff + TileHeight;
index_t i_tile_right = i_tile_left + tile_w;
index_t i_tile_bottom = y_eff + tile_h;
index_t x_end = min(y_eff + x, x_total);
bool top_right_edge = i_tile_right > (y_eff + x);
@@ -242,7 +253,7 @@ struct GenericAttentionMask
else
{
// only need to check top-right corner > x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_right = i_tile_left + tile_w;
index_t x_end = min(y_eff + x, x_total);
bool top_right_edge = i_tile_right > x_end;

unified_attention_kernel.hpp

@@ -263,7 +263,13 @@ struct UnifiedAttentionKernel
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
assert(kBlockM / num_queries_per_kv == kBlockQ);
// kBlockQ derived at runtime from num_queries_per_kv. For the variants
// we ship today this matches the compile-time `kBlockQ` from the
// pipeline trait (the assert below catches any disagreement); the
// explicit runtime form is what eventually lets a single kernel
// instantiation cover multiple num_queries_per_kv values.
const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv;
assert(kBlockQ_dyn == kBlockQ);
// Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
// split index lives in z — when num_splits == 1 (the only z value)
@@ -304,11 +310,11 @@ struct UnifiedAttentionKernel
seq_idx = find_seq_idx(kargs.query_start_len_ptr,
q_block_global_idx,
kargs.num_seqs,
kBlockQ,
kBlockQ_dyn,
true);
const index_t q_block_start_idx =
kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx;
kargs.query_start_len_ptr[seq_idx] / kBlockQ_dyn + seq_idx;
q_block_local_idx =
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
@@ -319,7 +325,7 @@ struct UnifiedAttentionKernel
cur_batch_query_len =
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
if(q_block_local_idx * kBlockQ >= cur_batch_query_len)
if(q_block_local_idx * kBlockQ_dyn >= cur_batch_query_len)
{
return;
}
@@ -328,13 +334,13 @@ struct UnifiedAttentionKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ);
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ_dyn);
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));
(context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));
if(seq_len < _max_seq_prefix_len)
{
@@ -384,9 +390,9 @@ struct UnifiedAttentionKernel
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t query_len_padded =
amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ);
// const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0);
index_t query_len_padded = amd_wave_read_first_lane(
integer_divide_ceil(cur_batch_query_len, kBlockQ_dyn) * kBlockQ_dyn);
// const bool is_query_len_padded = (cur_batch_query_len % kBlockQ_dyn == 0);
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
@@ -400,8 +406,9 @@ struct UnifiedAttentionKernel
const auto q_dram_pad =
pad_tensor_view( // aling seqlen with kBlockQ and head dim with kHeadDimPadded
q_dram_base,
// block sizes
make_tuple(number<kBlockQ>{}, 1, kHeadDimPadded),
// block sizes (kBlockQ is runtime here; pad_tensor_view
// accepts a mixed compile-time / runtime tuple)
make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// kHeadDimPadded)
@@ -509,7 +516,8 @@ struct UnifiedAttentionKernel
kargs.scale_s,
smem_ptr,
static_cast<long_index_t>(kargs.stride_k_cache_1),
static_cast<long_index_t>(kargs.stride_v_cache_1));
static_cast<long_index_t>(kargs.stride_v_cache_1),
num_queries_per_kv);
auto& o_acc_tile = pipeline_result[number<0>{}];
auto& lse_tile = pipeline_result[number<1>{}];
@@ -541,7 +549,7 @@ struct UnifiedAttentionKernel
const auto o_acc_pad = pad_tensor_view(
o_acc_base_view,
make_tuple(kBlockQ, 1, kHeadDimPadded),
make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{});
return transform_tensor_view(
@@ -581,7 +589,7 @@ struct UnifiedAttentionKernel
number<1>{});
const auto lse_acc_pad = pad_tensor_view(
lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
lse_acc_base_view, make_tuple(kBlockQ_dyn, 1), sequence<true, false>{});
return transform_tensor_view(
lse_acc_pad,
@@ -608,7 +616,7 @@ struct UnifiedAttentionKernel
const auto o_dram_pad =
pad_tensor_view(o_dram_base,
make_tuple(kBlockQ, 1, kHeadDimPadded),
make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{});
return transform_tensor_view(

unified_attention_pipeline.hpp

@@ -188,8 +188,13 @@ struct UnifiedAttentionPipeline
FmhaMask mask,
float scale_s,
void* smem_ptr,
long_index_t k_row_stride = 0,
long_index_t v_row_stride = 0) const
long_index_t k_row_stride = 0,
long_index_t v_row_stride = 0,
// Runtime kBlockQ = kBlockM / num_queries_per_kv. Default of 0 means
// "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`"
// so existing callers don't have to change. The kernel template passes
// the runtime value (from kargs) to remove the static dependency.
const index_t num_queries_per_kv = 0) const
{
using namespace ck_tile;
static_assert(
@@ -802,13 +807,20 @@ struct UnifiedAttentionPipeline
});
};
// Resolve kBlockQ at runtime when the caller plumbs in
// num_queries_per_kv (=> kBlockQ = kBlockM / num_qpkv). Fall back to
// the static `kBlockQ` from `UnifiedAttentionShape` when the caller
// passes 0 (back-compat). Stored once, reused per K-tile mask check.
const index_t kBlockQ_dyn =
(num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
auto fmha_mask = [&](auto sp_reg_idx) {
if constexpr(FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
i_total_loops * kPageBlockSize,
number<kBlockQ>{},
number<kPageBlockSize>{});
kBlockQ_dyn,
static_cast<index_t>(kPageBlockSize));
if(need_perpixel_check)
{
set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -1253,8 +1265,11 @@ struct UnifiedAttentionPipeline
FmhaMask mask,
float scale_s,
void* smem_ptr,
long_index_t k_row_stride = 0,
long_index_t v_row_stride = 0) const
long_index_t k_row_stride = 0,
long_index_t v_row_stride = 0,
// Forwards to the full-args operator() so callers can plumb in a
// runtime kBlockQ. See the documentation on that overload.
const index_t num_queries_per_kv = 0) const
{
using namespace ck_tile;
@@ -1276,7 +1291,8 @@ struct UnifiedAttentionPipeline
scale_s,
smem_ptr,
k_row_stride,
v_row_stride);
v_row_stride,
num_queries_per_kv);
}
};