From 2f42e4460f77f935f51b97eee13de3c505356757 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Aug 2024 10:53:42 +0000 Subject: [PATCH] Allow problem types without define kHasDropout attr --- .../block_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4285f0a55b..80fbc8e380 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -707,19 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy() + GetSmemSizeDropout(); + return GetSmemSizeKV() + GetSmemSizeDropout(0); } else { - return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout(0)); } } // this method is only available when Problem::kHasDropout is present template CK_TILE_HOST_DEVICE static constexpr std:: - enable_if_t, ck_tile::index_t> - GetSmemSizeDropout() + enable_if_t, ck_tile::index_t> + GetSmemSizeDropout(int) { if constexpr(Problem::kHasDropout) {