diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp
index 68eb6efc89..dd735c7689 100644
--- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp
+++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp
@@ -214,6 +214,17 @@ struct GenericAttentionMask
     template <index_t TileHeight, index_t TileWidth>
     CK_TILE_HOST_DEVICE constexpr auto
     IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
+    {
+        return IsEdgeTile(i_tile_top, i_tile_left, index_t{TileHeight}, index_t{TileWidth});
+    }
+
+    // Runtime overload. The compile-time variant above wraps this one so call
+    // sites that pass `number<>{}` keep working unchanged; callers that need a
+    // runtime tile size (e.g. when kBlockQ is derived from a runtime
+    // num_queries_per_kv) can call this directly. IsEdgeTile's body only does
+    // runtime arithmetic, so this is a no-op for current call sites.
+    CK_TILE_HOST_DEVICE constexpr auto
+    IsEdgeTile(index_t i_tile_top, index_t i_tile_left, index_t tile_h, index_t tile_w) const
     {
         // Transform the y index according to repeat_idx
         index_t y_eff = i_tile_top / repeat_idx;
@@ -221,15 +232,15 @@ struct GenericAttentionMask
         if constexpr(!IsMasking)
         {
             // TODO: no need to check begin
-            return (i_tile_left + TileWidth) > x_total;
+            return (i_tile_left + tile_w) > x_total;
         }
         else
         {
             if constexpr(IsLocal)
             {
                 // check top-right corner > x or left-bottom corner < x
-                index_t i_tile_right  = i_tile_left + TileWidth;
-                index_t i_tile_bottom = y_eff + TileHeight;
+                index_t i_tile_right  = i_tile_left + tile_w;
+                index_t i_tile_bottom = y_eff + tile_h;
                 index_t x_end = min(y_eff + x, x_total);

                 bool top_right_edge = i_tile_right > (y_eff + x);
@@ -242,7 +253,7 @@ struct GenericAttentionMask
             else
             {
                 // only need to check top-right corner > x
-                index_t i_tile_right = i_tile_left + TileWidth;
+                index_t i_tile_right = i_tile_left + tile_w;
                 index_t x_end = min(y_eff + x, x_total);

                 bool top_right_edge = i_tile_right > x_end;
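For review context: the forwarding pattern the first hunk introduces, reduced to a standalone toy. The `number<>`, `index_t`, and `Mask` below are simplified stand-ins for illustration, not the real ck_tile definitions; the point is only that the compile-time overload collapses to the runtime body, so existing `number<>{}` call sites lose nothing.

    #include <cstdint>

    using index_t = int32_t;

    template <index_t V>
    struct number {}; // stand-in for ck_tile::number<V>

    struct Mask
    {
        index_t x_total = 128;

        // Runtime overload: plain integer arithmetic, no dependency on a
        // template parameter.
        constexpr bool IsEdgeTile(index_t i_tile_left, index_t tile_w) const
        {
            return (i_tile_left + tile_w) > x_total;
        }

        // Compile-time variant forwards, so number<>{} call sites keep
        // compiling unchanged.
        template <index_t TileWidth>
        constexpr bool IsEdgeTile(index_t i_tile_left, number<TileWidth>) const
        {
            return IsEdgeTile(i_tile_left, index_t{TileWidth});
        }
    };

    static_assert(Mask{}.IsEdgeTile(120, number<16>{}), "old call site");
    static_assert(Mask{}.IsEdgeTile(120, 16), "new runtime call site");

Since the runtime body is still constexpr, the compiler can fold it whenever the tile size is in fact a constant, so the old compile-time path costs nothing.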
diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
index 865e62a315..b2593f5f63 100644
--- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
+++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
@@ -263,7 +263,13 @@ struct UnifiedAttentionKernel

         const index_t num_queries_per_kv = kargs.num_queries_per_kv;

-        assert(kBlockM / num_queries_per_kv == kBlockQ);
+        // kBlockQ derived at runtime from num_queries_per_kv. For the variants
+        // we ship today this matches the compile-time `kBlockQ` from the
+        // pipeline trait (the assert below catches any disagreement); the
+        // explicit runtime form is what eventually lets a single kernel
+        // instantiation cover multiple num_queries_per_kv values.
+        const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv;
+        assert(kBlockQ_dyn == kBlockQ);

         // Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
         // split index lives in z — when num_splits == 1 (the only z value)
@@ -304,11 +310,11 @@ struct UnifiedAttentionKernel
             seq_idx = find_seq_idx(kargs.query_start_len_ptr,
                                    q_block_global_idx,
                                    kargs.num_seqs,
-                                   kBlockQ,
+                                   kBlockQ_dyn,
                                    true);

             const index_t q_block_start_idx =
-                kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx;
+                kargs.query_start_len_ptr[seq_idx] / kBlockQ_dyn + seq_idx;

             q_block_local_idx =
                 amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
@@ -319,7 +325,7 @@ struct UnifiedAttentionKernel
             cur_batch_query_len = amd_wave_read_first_lane(cur_batch_in_all_stop_index -
                                                            cur_batch_in_all_start_index);

-            if(q_block_local_idx * kBlockQ >= cur_batch_query_len)
+            if(q_block_local_idx * kBlockQ_dyn >= cur_batch_query_len)
             {
                 return;
             }
@@ -328,13 +334,13 @@ struct UnifiedAttentionKernel

         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];

-        const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ);
+        const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ_dyn);
         const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
         const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);

         index_t _max_seq_prefix_len = amd_wave_read_first_lane(
-            (context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));
+            (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));

         if(seq_len < _max_seq_prefix_len)
         {
@@ -384,9 +390,9 @@ struct UnifiedAttentionKernel
         const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
         ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;

-        index_t query_len_padded =
-            amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ);
-        // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0);
+        index_t query_len_padded = amd_wave_read_first_lane(
+            integer_divide_ceil(cur_batch_query_len, kBlockQ_dyn) * kBlockQ_dyn);
+        // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ_dyn == 0);

         // Q/K/V DRAM and DRAM window
         const auto q_dram = [&]() {
@@ -400,8 +406,9 @@ struct UnifiedAttentionKernel

             const auto q_dram_pad = pad_tensor_view(
                 // aling seqlen with kBlockQ and head dim with kHeadDimPadded
                 q_dram_base,
-                // block sizes
-                make_tuple(number<kBlockQ>{}, 1, kHeadDimPadded),
+                // block sizes (kBlockQ is runtime here; pad_tensor_view
+                // accepts a mixed compile-time / runtime tuple)
+                make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                 sequence<true, false, true>{}); // pads to (seq_len_padded, num_head_q,
                                                 // kHeadDimPadded)
@@ -509,7 +516,8 @@ struct UnifiedAttentionKernel
                                          kargs.scale_s,
                                          smem_ptr,
                                          static_cast<long_index_t>(kargs.stride_k_cache_1),
-                                         static_cast<long_index_t>(kargs.stride_v_cache_1));
+                                         static_cast<long_index_t>(kargs.stride_v_cache_1),
+                                         num_queries_per_kv);

         auto& o_acc_tile = pipeline_result[number<0>{}];
         auto& lse_tile = pipeline_result[number<1>{}];
@@ -541,7 +549,7 @@ struct UnifiedAttentionKernel

             const auto o_acc_pad = pad_tensor_view(
                 o_acc_base_view,
-                make_tuple(kBlockQ, 1, kHeadDimPadded),
+                make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                 sequence<true, false, true>{});

             return transform_tensor_view(
@@ -581,7 +589,7 @@ struct UnifiedAttentionKernel
                 number<1>{});

             const auto lse_acc_pad = pad_tensor_view(
-                lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
+                lse_acc_base_view, make_tuple(kBlockQ_dyn, 1), sequence<true, false>{});

             return transform_tensor_view(
                 lse_acc_pad,
@@ -608,7 +616,7 @@ struct UnifiedAttentionKernel

             const auto o_dram_pad =
                 pad_tensor_view(o_dram_base,
-                                make_tuple(kBlockQ, 1, kHeadDimPadded),
+                                make_tuple(kBlockQ_dyn, 1, kHeadDimPadded),
                                 sequence<true, false, true>{});

             return transform_tensor_view(
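The padding math that switches from `kBlockQ` to `kBlockQ_dyn` above is worth sanity-checking with concrete numbers. A compilable toy follows; the values are illustrative only, in the kernel kBlockM is the compile-time M tile and num_queries_per_kv arrives through kargs, and integer_divide_ceil is re-derived locally rather than taken from ck_tile:

    #include <cassert>
    #include <cstdint>

    using index_t = int32_t;

    // Local re-derivation of the ceil-divide helper used in the hunk above.
    constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

    int main()
    {
        const index_t kBlockM            = 128; // compile-time M tile (illustrative)
        const index_t num_queries_per_kv = 8;   // runtime GQA group size (illustrative)
        const index_t kBlockQ_dyn        = kBlockM / num_queries_per_kv; // = 16

        // Same rounding the kernel applies to the per-sequence query length.
        const index_t cur_batch_query_len = 35;
        const index_t query_len_padded =
            integer_divide_ceil(cur_batch_query_len, kBlockQ_dyn) * kBlockQ_dyn;
        assert(query_len_padded == 48); // 35 rounded up to a multiple of 16
        return 0;
    }

With these numbers the assert in the hunk holds only for pipeline variants whose static kBlockQ is also 16, which is exactly the constraint the added comment describes.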
diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index d338eb8627..0127939066 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -188,8 +188,13 @@ struct UnifiedAttentionPipeline
                FmhaMask mask,
                float scale_s,
                void* smem_ptr,
-               long_index_t k_row_stride = 0,
-               long_index_t v_row_stride = 0) const
+               long_index_t k_row_stride = 0,
+               long_index_t v_row_stride = 0,
+               // Runtime kBlockQ = kBlockM / num_queries_per_kv. Default of 0 means
+               // "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`"
+               // so existing callers don't have to change. The kernel template passes
+               // the runtime value (from kargs) to remove the static dependency.
+               const index_t num_queries_per_kv = 0) const
     {
         using namespace ck_tile;
         static_assert(
@@ -802,13 +807,20 @@ struct UnifiedAttentionPipeline
             });
         };

+        // Resolve kBlockQ at runtime when the caller plumbs in
+        // num_queries_per_kv (=> kBlockQ = kBlockM / num_qpkv). Fall back to
+        // the static `kBlockQ` from `UnifiedAttentionShape` when the caller
+        // passes 0 (back-compat). Stored once, reused per K-tile mask check.
+        const index_t kBlockQ_dyn =
+            (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
+
         auto fmha_mask = [&](auto sp_reg_idx) {
             if constexpr(FmhaMask::IsMasking)
             {
                 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                            i_total_loops * kPageBlockSize,
-                                                           number<kBlockQ>{},
-                                                           number<kPageBlockSize>{});
+                                                           kBlockQ_dyn,
+                                                           static_cast<index_t>(kPageBlockSize));
                 if(need_perpixel_check)
                 {
                     set_tile_if(sp(sp_reg_idx).sp_compute,
@@ -1253,8 +1265,11 @@ struct UnifiedAttentionPipeline
                FmhaMask mask,
                float scale_s,
                void* smem_ptr,
-               long_index_t k_row_stride = 0,
-               long_index_t v_row_stride = 0) const
+               long_index_t k_row_stride = 0,
+               long_index_t v_row_stride = 0,
+               // Forwards to the full-args operator() so callers can plumb in a
+               // runtime kBlockQ. See the documentation on that overload.
+               const index_t num_queries_per_kv = 0) const
     {
         using namespace ck_tile;

@@ -1276,7 +1291,8 @@ struct UnifiedAttentionPipeline
                        scale_s,
                        smem_ptr,
                        k_row_stride,
-                       v_row_stride);
+                       v_row_stride,
+                       num_queries_per_kv);
     }
 };
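To make the back-compat contract of the new defaulted parameter concrete, a hypothetical reduction (Pipeline and resolve_block_q are illustrative names, not ck_tile API):

    #include <cstdint>

    using index_t = int32_t;

    template <index_t kBlockM, index_t kBlockQ>
    struct Pipeline
    {
        // Mirrors the kBlockQ_dyn resolution in the hunk above: 0 selects the
        // compile-time tile size, anything positive derives it at runtime.
        constexpr index_t resolve_block_q(index_t num_queries_per_kv = 0) const
        {
            return (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
        }
    };

    int main()
    {
        constexpr Pipeline<128, 16> p{};
        static_assert(p.resolve_block_q() == 16, "legacy caller, static fallback");
        static_assert(p.resolve_block_q(4) == 32, "runtime group size wins");
        return 0;
    }

The 0 sentinel keeps every existing call site source-compatible while the kernel passes the runtime value explicitly.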