From 12c88731c6615019d5fe846573ba8e836fdec2fb Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 9 Dec 2025 08:07:35 +0000 Subject: [PATCH] Separate kN0Sub from kK0 to be used for flexible tile tuning for whole_k_prefetch pipeline --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 2 +- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 147 +++++++++--------- ..._ks_vs_whole_k_prefetch_default_policy.hpp | 10 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 17 +- 4 files changed, 91 insertions(+), 85 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6cfa862a1a..24496cc755 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1349,7 +1349,7 @@ struct FmhaFwdKernel if constexpr(detail::is_n0loop_pipeline_v) { return pad_tensor_view(k_dram_naive, - make_tuple(number{}, + make_tuple(number{}, number{}), sequence{}); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 23c84cba9e..44d51d13db 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -42,6 +42,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = BlockFmhaShape::kN0Sub; static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; @@ -177,14 +178,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; - // usually kN0 is 128, kK1 is 32/16 + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); static_assert(k1_loops >= NumPrefetchV, "Check failed!"); constexpr bool kPreloadWholeNextIterationK = @@ -196,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch // SaccBlockTile size is [kM0, kK1] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); @@ -227,7 +231,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -236,13 +240,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch // only prefetch two k tiles to save vgprs consumption auto k_tiles = [&]() { if constexpr(kPreloadWholeNextIterationK) - return statically_indexed_array{}; + return statically_indexed_array{}; else return statically_indexed_array{}; }(); k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); __builtin_amdgcn_sched_barrier(0x00000001); @@ -258,11 +262,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; @@ -270,11 +274,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -371,75 +376,75 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}]), + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_k1 < k1_loops - 1) + if constexpr(i_n0 < n0_loops - 1) { - k_tiles[number{}] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_k1 < NumPrefetchV) + if constexpr(i_n0 < NumPrefetchV) { - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; - if constexpr(i_k1 == k1_loops - 1) + if constexpr(i_n0 == n0_loops - 1) { // prefetch all k_tiles for next iteration - static_for<0, k1_loops, 1>{}([&](auto ii_k1) { - k_tiles[number{}] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }); }; block_sync_lds(); gemm_0(sacc_tile, q_tile, - k_lds_read_windows[number{}]); + k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); } else // the iteration is also the last iteration { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}]), + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_k1 < k1_loops - 1) + if constexpr(i_n0 < n0_loops - 1) { - k_tiles[number{}] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_k1 < NumPrefetchV) + if constexpr(i_n0 < NumPrefetchV) { - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; block_sync_lds(); gemm_0(sacc_tile, q_tile, - k_lds_read_windows[number{}]); + k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); }; } @@ -447,87 +452,87 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch { if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}]), + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_k1 < NumPrefetchV) + if constexpr(i_n0 < NumPrefetchV) { - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); // prefetch k_tile for next iteration - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; // prefetch other k_tiles for next iteration - if constexpr(i_k1 >= NumPrefetchV) + if constexpr(i_n0 >= NumPrefetchV) { - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; block_sync_lds(); gemm_0(sacc_tile, q_tile, - k_lds_read_windows[number{}]); + k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); } else // last iteration { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}]), + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); - if constexpr(i_k1 < NumPrefetchV) + if constexpr(i_n0 < NumPrefetchV) { - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; block_sync_lds(); gemm_0(sacc_tile, q_tile, - k_lds_read_windows[number{}]); + k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); }; } } else // only preload one unroll of K for next iteration { - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[I0]), partition_index); __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(i_k1 < k1_loops - 1) + if constexpr(i_n0 < n0_loops - 1) { k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; - if constexpr(i_k1 < NumPrefetchV) + if constexpr(i_n0 < NumPrefetchV) { - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }; @@ -535,14 +540,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch block_sync_lds(); - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); } @@ -677,7 +682,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch // check whether first V-LdsBufer overlap with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { __builtin_amdgcn_s_barrier(); }; @@ -696,7 +701,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end) { k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0, 0}); }; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 0ac65ec094..a1cc228812 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); @@ -170,7 +170,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -213,7 +213,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -320,7 +320,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); @@ -455,7 +455,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy typename Problem::SaccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index ee5238869f..c9d6d4ac60 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -47,23 +47,24 @@ struct TileFmhaShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen - static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen - static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll - static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim - static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll + static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0 + static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll static constexpr index_t kQKHeaddim = BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + static_assert(kQKHeaddim % kK0 == 0 || kN0 % kN0Sub == 0, "Check failed!"); static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(); // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; using VLayout = std::conditional_t; + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor>; }; template