mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
CK-UA: derive kBlockQ at runtime, decouple from variant template
kBlockQ (= kBlockM / num_queries_per_kv) was constexpr in
`UnifiedAttentionShape` / the kernel-traits, forcing one kernel
instance per (kBlockM, num_qpkv) pair even though the matmul tile is
fully determined by kBlockM and kHeadDim. Audit confirmed kBlockQ
only feeds:
* arithmetic in `unified_attention_kernel.hpp` (loop bounds, Q-tile
indexing, query_len padding),
* `pad_tensor_view` size tuples for Q/O/LSE DRAM views,
* one `mask.IsEdgeTile(... number<kBlockQ>{} ...)` call inside the
pipeline's per-K-tile mask check.
None of these structurally need a compile-time value:
* `pad_tensor_view` already accepts mixed runtime/compile-time tuple
elements (e.g. it's passed plain `1` next to `kHeadDimPadded`).
* `IsEdgeTile` only does runtime arithmetic on the tile size; adding a
runtime overload that accepts `index_t` is trivial (the compile-time
one now forwards to it).
Wiring:
* `block_masking.hpp` -- add an `IsEdgeTile(..., index_t tile_h,
index_t tile_w)` overload; the existing `number<>` overload just
forwards to it.
* `unified_attention_pipeline.hpp` -- new optional
`num_queries_per_kv` arg on the pipeline's `operator()` (default 0
keeps existing call sites unchanged). Computes
`kBlockQ_dyn = (num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`
once at the top, uses it in the IsEdgeTile call.
* `unified_attention_kernel.hpp` -- compute
`const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv`
once and replace every per-call `kBlockQ` use with `kBlockQ_dyn`.
Pass `kargs.num_queries_per_kv` through to the pipeline. The
debug-only assert(`kBlockQ_dyn == kBlockQ`) keeps the static and
dynamic values in lock-step until we actually collapse variants.
Perf A/B (b=4..256, sk=120000, MI300):
d=128 MHA (num_qpkv = 1, runtime div is trivial):
BW within +/-0.2% across all batch sizes (noise).
d=64 GQA-8 (num_qpkv = 8, runtime division actually happens):
speedups 1.28x..2.14x vs Triton -- identical to baseline.
Correctness suite stays at 241/245 (same 4 pre-existing int32-overflow
failures in the d=128 prefill rebased-pointer path).
This is a no-op on perf and unlocks a follow-up where we collapse the
two num_qpkv values per (head_dim, kBlockM) -- e.g. the future d=128
GQA-8 variant can reuse the existing decode_d128_mha_* instances by
just passing a different runtime num_queries_per_kv.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -214,6 +214,17 @@ struct GenericAttentionMask
     template <index_t TileHeight, index_t TileWidth>
     CK_TILE_HOST_DEVICE constexpr auto
     IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
     {
+        return IsEdgeTile(i_tile_top, i_tile_left, index_t{TileHeight}, index_t{TileWidth});
+    }
+
+    // Runtime overload. The compile-time variant above wraps this one so call
+    // sites that pass `number<>{}` keep working unchanged; callers that need a
+    // runtime tile size (e.g. when kBlockQ is derived from a runtime
+    // num_queries_per_kv) can call this directly. IsEdgeTile's body only does
+    // runtime arithmetic, so this is a no-op for current call sites.
+    CK_TILE_HOST_DEVICE constexpr auto
+    IsEdgeTile(index_t i_tile_top, index_t i_tile_left, index_t tile_h, index_t tile_w) const
+    {
         // Transform the y index according to repeat_idx
         index_t y_eff = i_tile_top / repeat_idx;
@@ -221,15 +232,15 @@ struct GenericAttentionMask
         if constexpr(!IsMasking)
         {
             // TODO: no need to check begin
-            return (i_tile_left + TileWidth) > x_total;
+            return (i_tile_left + tile_w) > x_total;
         }
         else
         {
             if constexpr(IsLocal)
             {
                 // check top-right corner > x or left-bottom corner < x
-                index_t i_tile_right = i_tile_left + TileWidth;
-                index_t i_tile_bottom = y_eff + TileHeight;
+                index_t i_tile_right = i_tile_left + tile_w;
+                index_t i_tile_bottom = y_eff + tile_h;
                 index_t x_end = min(y_eff + x, x_total);

                 bool top_right_edge = i_tile_right > (y_eff + x);
@@ -242,7 +253,7 @@ struct GenericAttentionMask
             else
             {
                 // only need to check top-right corner > x
-                index_t i_tile_right = i_tile_left + TileWidth;
+                index_t i_tile_right = i_tile_left + tile_w;
                 index_t x_end = min(y_eff + x, x_total);

                 bool top_right_edge = i_tile_right > x_end;
@@ -263,7 +263,13 @@ struct UnifiedAttentionKernel

        const index_t num_queries_per_kv = kargs.num_queries_per_kv;

-       assert(kBlockM / num_queries_per_kv == kBlockQ);
+       // kBlockQ derived at runtime from num_queries_per_kv. For the variants
+       // we ship today this matches the compile-time `kBlockQ` from the
+       // pipeline trait (the assert below catches any disagreement); the
+       // explicit runtime form is what eventually lets a single kernel
+       // instantiation cover multiple num_queries_per_kv values.
+       const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv;
+       assert(kBlockQ_dyn == kBlockQ);

        // Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
        // split index lives in z — when num_splits == 1 (the only z value)
@@ -304,11 +310,11 @@ struct UnifiedAttentionKernel
            seq_idx = find_seq_idx(kargs.query_start_len_ptr,
                                   q_block_global_idx,
                                   kargs.num_seqs,
-                                  kBlockQ,
+                                  kBlockQ_dyn,
                                   true);

            const index_t q_block_start_idx =
-               kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx;
+               kargs.query_start_len_ptr[seq_idx] / kBlockQ_dyn + seq_idx;

            q_block_local_idx =
                amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
@@ -319,7 +325,7 @@ struct UnifiedAttentionKernel
        cur_batch_query_len =
            amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);

-       if(q_block_local_idx * kBlockQ >= cur_batch_query_len)
+       if(q_block_local_idx * kBlockQ_dyn >= cur_batch_query_len)
        {
            return;
        }
@@ -328,13 +334,13 @@ struct UnifiedAttentionKernel
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

-       const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ);
+       const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ_dyn);
        const index_t seq_len = kargs.seq_lens_ptr[seq_idx];

        const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);

        index_t _max_seq_prefix_len = amd_wave_read_first_lane(
-           (context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));
+           (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));

        if(seq_len < _max_seq_prefix_len)
        {
@@ -384,9 +390,9 @@ struct UnifiedAttentionKernel
        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;

-       index_t query_len_padded =
-           amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ);
-       // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0);
+       index_t query_len_padded = amd_wave_read_first_lane(
+           integer_divide_ceil(cur_batch_query_len, kBlockQ_dyn) * kBlockQ_dyn);
+       // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ_dyn == 0);

        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&]() {
@@ -400,8 +406,9 @@ struct UnifiedAttentionKernel
            const auto q_dram_pad =
                pad_tensor_view( // aling seqlen with kBlockQ and head dim with kHeadDimPadded
                    q_dram_base,
-                   // block sizes
-                   make_tuple(number<kBlockQ>{}, 1, kHeadDimPadded),
+                   // block sizes (kBlockQ is runtime here; pad_tensor_view
+                   // accepts a mixed compile-time / runtime tuple)
+                   make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                    sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
                                                            // kHeadDimPadded)

@@ -509,7 +516,8 @@ struct UnifiedAttentionKernel
            kargs.scale_s,
            smem_ptr,
            static_cast<long_index_t>(kargs.stride_k_cache_1),
-           static_cast<long_index_t>(kargs.stride_v_cache_1));
+           static_cast<long_index_t>(kargs.stride_v_cache_1),
+           num_queries_per_kv);
        auto& o_acc_tile = pipeline_result[number<0>{}];
        auto& lse_tile = pipeline_result[number<1>{}];

@@ -541,7 +549,7 @@ struct UnifiedAttentionKernel

            const auto o_acc_pad = pad_tensor_view(
                o_acc_base_view,
-               make_tuple(kBlockQ, 1, kHeadDimPadded),
+               make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                sequence<true, false, kPadHeadDimQ>{});

            return transform_tensor_view(
@@ -581,7 +589,7 @@ struct UnifiedAttentionKernel
                number<1>{});

            const auto lse_acc_pad = pad_tensor_view(
-               lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
+               lse_acc_base_view, make_tuple(kBlockQ_dyn, 1), sequence<true, false>{});

            return transform_tensor_view(
                lse_acc_pad,
@@ -608,7 +616,7 @@ struct UnifiedAttentionKernel

            const auto o_dram_pad =
                pad_tensor_view(o_dram_base,
-                               make_tuple(kBlockQ, 1, kHeadDimPadded),
+                               make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                                sequence<true, false, kPadHeadDimQ>{});

            return transform_tensor_view(
@@ -188,8 +188,13 @@ struct UnifiedAttentionPipeline
               FmhaMask mask,
               float scale_s,
               void* smem_ptr,
-              long_index_t k_row_stride = 0,
-              long_index_t v_row_stride = 0) const
+              long_index_t k_row_stride = 0,
+              long_index_t v_row_stride = 0,
+              // Runtime kBlockQ = kBlockM / num_queries_per_kv. Default of 0 means
+              // "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`"
+              // so existing callers don't have to change. The kernel template passes
+              // the runtime value (from kargs) to remove the static dependency.
+              const index_t num_queries_per_kv = 0) const
    {
        using namespace ck_tile;
        static_assert(
@@ -802,13 +807,20 @@ struct UnifiedAttentionPipeline
                });
        };

+       // Resolve kBlockQ at runtime when the caller plumbs in
+       // num_queries_per_kv (=> kBlockQ = kBlockM / num_qpkv). Fall back to
+       // the static `kBlockQ` from `UnifiedAttentionShape` when the caller
+       // passes 0 (back-compat). Stored once, reused per K-tile mask check.
+       const index_t kBlockQ_dyn =
+           (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
+
        auto fmha_mask = [&](auto sp_reg_idx) {
            if constexpr(FmhaMask::IsMasking)
            {
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           i_total_loops * kPageBlockSize,
-                                                          number<kBlockQ>{},
-                                                          number<kPageBlockSize>{});
+                                                          kBlockQ_dyn,
+                                                          static_cast<index_t>(kPageBlockSize));
                if(need_perpixel_check)
                {
                    set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -1253,8 +1265,11 @@ struct UnifiedAttentionPipeline
               FmhaMask mask,
               float scale_s,
               void* smem_ptr,
-              long_index_t k_row_stride = 0,
-              long_index_t v_row_stride = 0) const
+              long_index_t k_row_stride = 0,
+              long_index_t v_row_stride = 0,
+              // Forwards to the full-args operator() so callers can plumb in a
+              // runtime kBlockQ. See the documentation on that overload.
+              const index_t num_queries_per_kv = 0) const
    {
        using namespace ck_tile;

@@ -1276,7 +1291,8 @@ struct UnifiedAttentionPipeline
            scale_s,
            smem_ptr,
            k_row_stride,
-           v_row_stride);
+           v_row_stride,
+           num_queries_per_kv);
    }
};

Reference in New Issue
Block a user