diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 143a970e01..9ef3a675a7 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -109,6 +109,8 @@ using is_known_at_compile_time = is_static; { \ }; +DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy); + // FIXME: do we need this anymore? template < typename PY, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 250fe3af38..9f6edbad26 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -21,7 +21,13 @@ struct GemmPipelineAgBgCrImplBase using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - static constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(has_bcastpolicy::value) + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + else + return false; + }(); + using BDataType = std::conditional_t; using BLayout = remove_cvref_t{}, BsLayout>>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index db00f87fd5..a6c02232a8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -11,16 +11,6 @@ namespace ck_tile { -DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy); - -template -static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] { - if constexpr(has_bcastpolicy::value) - return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; - else - return false; -}(); - template struct has_a_tile_access_pattern : std::false_type { @@ -90,6 +80,14 @@ struct UniversalGemmBasePolicy static constexpr bool is_b_load_tr = false; #endif + template + static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] { + if constexpr(has_bcastpolicy::value) + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + else + return false; + }(); + static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{};