From 908afb3a551e64ed98a2e400d39a7756cf269310 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 13 Jan 2026 12:04:43 +0800 Subject: [PATCH] [FMHA] Support page_size=1 (linear layout) in batch prefill pipeline (#3545) - Enable page_size=1 support in batch prefill codegen (linear layout only). - Implement per-token page lookup in `kv_offset_array_transform` for page_size=1 to handle 3D input tensors correctly. - Relax `kPageBlockSize` alignment assertion for the page_size=1 case. [ROCm/composable_kernel commit: c9f112b0267625016a58ce3465ee34232c85812b] --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 4 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 96 ++++++++++++------- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index c4c70009d5..37d296aa91 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,7 +36,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [128, 256, 1024] +SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { @@ -737,6 +737,8 @@ def get_fwd_blobs( # Generate kernels for both page_size=16 and page_size=1024 for page_size in SUPPORTED_PAGE_SIZE: + if page_size == 1 and pipeline.F_kv_memory_layout != "linear": + continue k = FmhaFwdKernel( F_idx=0, F_hdim=hdim, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 0b47441995..4ee705913b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -24,9 +24,9 @@ template -CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, - const index_t& stride_kv, - const index_t& page_stride_kv, +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, + const index_t& stride_token, + const index_t& stride_page_block, const CoordVecType& coord_vec, OffsetVecType& kv_offset_vec, index_t global_seq_offset = 0) @@ -39,47 +39,70 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - const index_t page_offset = global_token_idx & kInPageOffsetMask; - kv_offset_vec[k0] = static_cast(page_vec[page_id]) * page_stride_kv + - static_cast(page_offset) * stride_kv; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + kv_offset_vec[k0] = static_cast(page_idx[page_id]) * stride_page_block + + static_cast(token_idx_in_page) * stride_token; }); } else { // for v offsets - const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); - const index_t lane0_page_id = - (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + if constexpr(kLog2PageSize == 0 && + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) + { + // page size = 1, per-token page lookup. + // Here page_idx maps token_idx -> physical_page_id, so global_seq_offset must be + // the absolute token index within the batch's kv_page_indices slice. + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const long_index_t page_loc = - static_cast(page_vec[lane0_page_id]) * page_stride_kv; + const long_index_t page_base_offset = + static_cast(page_idx[global_token_idx]) * stride_page_block; - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t page_offset = - (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & - kInPageOffsetMask; + kv_offset_vec[k0] = page_base_offset; + }); + } + else + { + // This path handles page_size > 1 and/or non-linear KV layout, where page_idx is + // indexed by page_id (token_idx >> log2_page_size) with an in-page offset. + // Assumes the V tile stays within a single page so lane0 can broadcast the page id. + const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); + const index_t lane0_page_id = + (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - // Vectorized layout offset - // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] - // Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize) - const index_t s = page_offset; - const index_t D = stride_kv; + const long_index_t page_base_offset = + static_cast(page_idx[lane0_page_id]) * stride_page_block; - const long_index_t s_offset = - static_cast((s / kVectorSize) * (D * kVectorSize)) + - (s % kVectorSize); + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t token_idx_in_page = + (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + kInPageOffsetMask; - kv_offset_vec[k0] = page_loc + s_offset; - } - else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT - { - kv_offset_vec[k0] = page_loc + static_cast(page_offset) * stride_kv; - } - }); + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset + // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] + // Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) + + // (token_idx_in_page % kVectorSize) + + const long_index_t token_offset = + static_cast((token_idx_in_page / kVectorSize) * + (stride_token * kVectorSize)) + + (token_idx_in_page % kVectorSize); + + kv_offset_vec[k0] = page_base_offset + token_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_base_offset + + static_cast(token_idx_in_page) * stride_token; + } + }); + } } } @@ -127,9 +150,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static_assert(kPageBlockSize % kN0 == 0, - "V offset assumes each tile stays within a page; kPageBlockSize must be " - "divisible by kN0."); + static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0, + "Page size must be 1, or a multiple of the tile size (kN0)."); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)