From d281c519f384a75be2774a33c862b1c6d7c63403 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Mon, 15 Dec 2025 15:02:15 +0000
Subject: [PATCH] Adjust in GetNumPrefetchV()

---
 ...peline_qr_ks_vs_whole_k_prefetch_default_policy.hpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
index 046d07909a..3e430ff476 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
@@ -26,12 +26,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
     {
+        constexpr index_t n0_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kN0Sub;
         constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1;
 
-        // usually kN0 is 128, kK1 is 32/16
-        static_assert(k1_loops >= 2, "Check failed!");
-
-        return 2;
+        if constexpr(n0_loops >= 4 && k1_loops >= 6)
+            return 3;
+        if constexpr(k1_loops >= 4)
+            return 2;
+        return 1;
     };
 
     template