diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 2a3d3cd75b..108ccc0425 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -157,36 +157,36 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseFwdXdlGemmDescriptor = requires { - { T::ak1 } -> std::convertible_to; - { T::bk1 } -> std::convertible_to; - { T::xdl_params } -> GridwiseXdlGemmDescriptor; +concept GridwiseFwdXdlGemmDescriptor = requires (T t){ + { t.ak1 } -> std::convertible_to; + { t.bk1 } -> std::convertible_to; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept GridwiseBwdXdlGemmDescriptor = requires { - { T::k0_per_block } -> std::convertible_to; - { T::k1 } -> std::convertible_to; - { T::xdl_params } -> GridwiseXdlGemmDescriptor; +concept GridwiseBwdXdlGemmDescriptor = requires (T t){ + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> std::convertible_to; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseFwdXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +concept SpecifiesGridwiseFwdXdlGemm = requires (T t) { + { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseBwdXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +concept SpecifiesGridwiseBwdXdlGemm = requires (T t) { + { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. template -concept SpecifiesGridwiseWmmaGemm = requires { - { T::gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; +concept SpecifiesGridwiseWmmaGemm = requires (T t){ + { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 6f3a9e8e78..d7f3b17197 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -161,7 +161,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC template consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization() { - constexpr auto specialization = ALGORITHM.bwd_specialization; + constexpr auto specialization = ALGORITHM.bwd_weight_specialization; using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; switch(specialization) { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 975212999b..045efbc385 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -25,7 +25,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::Transfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); - +static_assert(cku::SpecifiesGridwiseBwdXdlGemm, "Error"); + using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2767814e11..4d5ac2cd9e 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -211,7 +211,7 @@ struct ConvSpecializationFwd_ struct ConvSpecializationBwdWeight_ { - ConvSpecialization bwd_specialization; + ConvSpecialization bwd_weight_specialization; }; struct Prefetch_ @@ -400,7 +400,7 @@ struct ConvAlgorithmTemplate : Components... { static_assert(std::is_base_of_v); auto result = *this; - result.bwd_specialization = bwd_spec; + result.bwd_weight_specialization = bwd_spec; return result; } diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 6c1d9ae15f..c3afe2bd4e 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -278,7 +278,7 @@ template <> inline std::string to_string(ConvSpecializationBwdWeight_ t) { std::ostringstream oss; - oss << to_string(t.bwd_specialization); + oss << to_string(t.bwd_weight_specialization); return oss.str(); }