From 28221cf01f9a58a1beba0cad678804166bbb4db5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 15 Oct 2025 07:46:47 +0000 Subject: [PATCH] Value constraint concept example. --- .../include/ck_tile/builder/conv_algorithm.hpp | 11 +++++++++-- .../builder/include/ck_tile/builder/conv_factory.hpp | 5 +++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index 01e5cebe1a..0dfc7d022c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -86,12 +86,19 @@ concept VectorTransferDescriptorAB = requires(T t) { }; // Concept for the C tensor vectors transfer details. +template +concept ValidVectorTransferC = requires { + requires Value.scaler_per_vector > 0 && + Value.m_xdl_per_wave_per_shuffle > 0 && + Value.n_xdl_per_wave_per_shuffle > 0 ; +}; + template concept VectorTransferDescriptorC = requires(T t) { { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; { t.n_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.scaler_per_vector } -> std::convertible_to; -}; + { t.scaler_per_vector } -> std::convertible_to; +}; // Concept to check if a struct specifies A Block tranfer info. template diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index edc390baa7..01675641e6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -406,6 +407,10 @@ struct ConvFactory static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion(); + + // Preconditions + static_assert(ValidVectorTransferC); + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //