diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 37d296aa91..9a2d727253 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,7 +36,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024] +SUPPORTED_PAGE_SIZE = [1, 16, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 09b3f07883..c75f5d58c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -17,12 +17,12 @@ template CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, const index_t& stride_token, @@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, OffsetVecType& kv_offset_vec, index_t global_seq_offset = 0) { + static constexpr index_t kLog2PageSize = [] { + index_t shift = 0; + index_t val = kPageBlockSize; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); + const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; if constexpr(kIsKcache) @@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, else { // for v offsets - if constexpr(kLog2PageSize == 0 && + // for page_size > 1, the V tile crosses pages when page_size is not a multiple of kN0. + static constexpr bool kVTileCrossesPages = + (kPageBlockSize > 1) && (kPageBlockSize % kN0 != 0); + if constexpr(kPageBlockSize == 1 && kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { // page size = 1, per-token page lookup. @@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, kv_offset_vec[k0] = page_base_offset; }); } - else + else if constexpr(kVTileCrossesPages) { - // This path handles page_size > 1 and/or non-linear KV layout, where page_idx is - // indexed by page_id (token_idx >> log2_page_size) with an in-page offset. - // Assumes the V tile stays within a single page so lane0 can broadcast the page id. + // V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed + // per token. + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + + const long_index_t page_base_offset = + static_cast(page_idx[page_id]) * stride_page_block; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize] + // address pattern. + const long_index_t token_offset = + static_cast((token_idx_in_page / kVectorSize) * + (stride_token * kVectorSize)) + + (token_idx_in_page % kVectorSize); + + kv_offset_vec[k0] = page_base_offset + token_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_base_offset + + static_cast(token_idx_in_page) * stride_token; + } + }); + } + else // !kVTileCrossesPages + { + // V tile is fully contained in one page, so page_id is shared. + // Use lane0 to compute page_id once and broadcast page_base_offset. const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; @@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_cast(page_idx[lane0_page_id]) * stride_page_block; static_for<0, kLoopCount, 1>{}([&](auto k0) { + // kLoopStride allows non-unit token spacing in the tile distribution. const index_t token_idx_in_page = - (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + (global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) & kInPageOffsetMask; if constexpr(kKVMemoryLayout == @@ -142,7 +188,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; - static constexpr index_t kLog2PageSize = Problem::kLog2PageSize; static constexpr index_t kVectorSize = Problem::kVectorSize; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -150,9 +195,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0, - "Page size must be 1, or a multiple of the tile size (kN0)."); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) // only need special care about seq_k padding (oob need set -INF of p instead of zero) @@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(k_coord), 0, kPageBlockSize, - kLog2PageSize, 0, NRepeat, kN0 / NRepeat, kKVMemoryLayout, true, + kN0, kVectorSize>( page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); @@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, 0, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); @@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, kK1, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, 2 * kK1, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(v_coord), VPageIndexDim, kPageBlockSize, - kLog2PageSize, (2 + i_k1.value) * kK1, V_KRepeat, 1, kKVMemoryLayout, false, + kN0, kVectorSize>( page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); @@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync decltype(k_coord), 0, kPageBlockSize, - kLog2PageSize, 0, NRepeat, kN0 / NRepeat, kKVMemoryLayout, true, + kN0, kVectorSize>( page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index f9dc94bc65..a489eabb73 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -107,16 +107,6 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive"); static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); - static constexpr index_t kLog2PageSize = []() constexpr { - index_t shift = 0; - index_t val = kPageBlockSize_; - while(val > 1) - { - val >>= 1; - shift++; - } - return shift; - }(); static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; @@ -126,6 +116,8 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0, "kQKHeaddim must be divisible by kVectorSize"); + static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout), + "page_size=1 only supports linear KV cache layout"); static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); static_assert(kIsGroupMode_, "Batch prefill requires group mode");