diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index 01e5cebe1a..0dfc7d022c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -86,12 +86,19 @@ concept VectorTransferDescriptorAB = requires(T t) { }; // Concept for the C tensor vectors transfer details. +template +concept ValidVectorTransferC = requires { + requires Value.scaler_per_vector > 0 && + Value.m_xdl_per_wave_per_shuffle > 0 && + Value.n_xdl_per_wave_per_shuffle > 0 ; +}; + template concept VectorTransferDescriptorC = requires(T t) { { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; { t.n_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.scaler_per_vector } -> std::convertible_to; -}; + { t.scaler_per_vector } -> std::convertible_to; +}; // Concept to check if a struct specifies A Block tranfer info. template diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index edc390baa7..01675641e6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -406,6 +407,10 @@ struct ConvFactory static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion(); + + // Preconditions + static_assert(ValidVectorTransferC); + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //