From e7e6ebc91c8050d372eae106e60ababe84c1bd6a Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 22 Dec 2025 15:35:38 +0000 Subject: [PATCH] Update to GetNumPrefetchV() --- ...ine_qr_ks_vs_whole_k_prefetch_default_policy.hpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index cc342c936d..923857c62c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -30,11 +30,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t n0_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kN0Sub; constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1; - if constexpr(n0_loops >= 4 && k1_loops >= 6) - return 3; - if constexpr(k1_loops >= 4) + if constexpr(Problem::kUseTrLoad) + { + if constexpr(n0_loops >= 4 && k1_loops >= 6) + return 3; return 2; - return 1; + } + else + { + return 2; + }; }; template