diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index cd1bfb031b..93b415b4ce 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1641,7 +1641,7 @@ struct FmhaFwdKernel // 2. use more LDS, as we want better memory latency hiding // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the // cache - constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; + constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 08fc42a471..dce9583fc1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -37,7 +37,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64; + static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 > 64; static constexpr index_t kBlockSize = Problem::kBlockSize;