diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 923857c62c..3ab239ad20 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -32,9 +32,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy if constexpr(Problem::kUseTrLoad) { - if constexpr(n0_loops >= 4 && k1_loops >= 6) - return 3; - return 2; + // kM0 is 64, kN0 is 128, prefetch all k_tiles + if constexpr(IsPreloadWholeNextIterationK()) + { + if constexpr(n0_loops >= 4 && k1_loops >= 6) + return 2; + return 2; + } + else // kM0 is 128, kN0 is 64, prefetch one k_tile + { + // kN0 == 64, try to prefetch more v_tiles + return 2; + }; } else {