diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4285f0a55b..80fbc8e380 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -707,19 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy() + GetSmemSizeDropout(); + return GetSmemSizeKV() + GetSmemSizeDropout(0); } else { - return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout(0)); } } // this method is only available when Problem::kHasDropout is present template CK_TILE_HOST_DEVICE static constexpr std:: - enable_if_t, ck_tile::index_t> - GetSmemSizeDropout() + enable_if_t, ck_tile::index_t> + GetSmemSizeDropout(int) { if constexpr(Problem::kHasDropout) {