From 6729989b97f4904991313ee95c203f65380259e7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 Apr 2026 16:24:19 +0000 Subject: [PATCH] Fix FMHA split-KV for paged-KV with page_block_size < kN0 Cherry-picked from aghamari/unified-attention-decode-opt (fadf0d585). - block_masking.hpp: 5-param GetTileRangeAlongX for GenericAttentionMask - fmha_fwd_splitkv.py: bn0=32 for hdim=64 Made-with: Cursor --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +- .../ck_tile/ops/fmha/block/block_masking.hpp | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index e0ccde8a6b..acc0f46fa9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -821,7 +821,7 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): if dtype in ["fp16", "bf16"]: return { "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), @@ -862,7 +862,7 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): return { # bm0, bn0, bk0, bn1, bk1, "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } # fmt: skip diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 4ffb303812..d45b3da603 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -180,6 +180,45 @@ struct GenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + + template + CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); + const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0; + const index_t start = ck_tile::max(origin_start, split_start); + const index_t end = ck_tile::min(origin_end, split_end); + const bool is_first_intersecting_split = + (split_start <= sink_seq_end) && (split_end > 0) && (sink > 0); + if(is_first_intersecting_split) + return ck_tile::make_tuple(0, 0, end); + else + return ck_tile::make_tuple(sink_seq_end, start, end); + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check)