Patch summary: GetNumPrefetchV() in BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy

  * Before: the function computed k1_loops = kN0 / kK1, required k1_loops >= 2
    through a static_assert, and always returned 2.
  * After: the function also computes n0_loops = kN0 / kN0Sub and picks the
    return value at compile time:
      - 3 when n0_loops >= 4 and k1_loops >= 6
      - 2 when k1_loops >= 4 (and the first condition did not hold)
      - 1 in every other case
  * The static_assert is removed, so a shape with k1_loops < 2 no longer fails
    to compile at this point; it now yields 1.

NOTE(review): the two "template" context lines carry no parameter list although
the body names Problem; presumably "<typename Problem>" was lost when this patch
was copied -- confirm against the header before applying.
NOTE(review): line breaks and indentation were restored from a whitespace-collapsed
copy, and one blank context line was assumed after the k1_loops declaration so the
hunk matches its 12/14 line counts; apply with whitespace tolerance or regenerate
the patch from the repository.

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
index 046d07909a..3e430ff476 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
@@ -26,12 +26,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
     template
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
     {
+        constexpr index_t n0_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kN0Sub;
         constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1;
 
-        // usually kN0 is 128, kK1 is 32/16
-        static_assert(k1_loops >= 2, "Check failed!");
-
-        return 2;
+        if constexpr(n0_loops >= 4 && k1_loops >= 6)
+            return 3;
+        if constexpr(k1_loops >= 4)
+            return 2;
+        return 1;
     };
 
     template