diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 261c3f103d..59ff83c238 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -251,14 +251,10 @@ class ConvDescription : public Description }; } // namespace conv -/// @brief Helper concept to detect if a type has ConvTraits specialization -template -concept HasConvTraits = requires { typename conv::ConvTraits; }; - /// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have InstanceTraits specialization) +/// @tparam Instance The convolution instance type (must have ConvTraits specialization) /// @return A ConvDescription object populated with the instance's configuration details -template +template conv::ConvDescription describe() { using Traits = conv::ConvTraits; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 29ac49e549..918fd6bdb6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -21,6 +21,57 @@ namespace ck_tile::reflect::conv { +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// GEMM specialization concept - checks for kGemmSpecialization member +template +concept HasGemmSpec = requires { + { + T::kGemmSpecialization + } -> std::convertible_to; +}; + +// Data types concept - checks for ADataType member +template +concept HasDataTypes = requires { typename T::ADataType; }; + +// Elementwise operations concept - checks for A/B/CDE elementwise operation types +template +concept HasElementwiseOps = requires { + typename T::AElementwiseOperation; + typename T::BElementwiseOperation; + typename T::CDEElementwiseOperation; +}; + +// Tile parameters concept - checks for tile dimension and transfer members +template +concept HasTileParams = requires { + { T::kKPerBlock } -> std::convertible_to; + { T::kMPerBlock } -> std::convertible_to; + { T::kNPerBlock } -> std::convertible_to; + { T::kAK1 } -> std::convertible_to; + { T::kBK1 } -> std::convertible_to; + T::kCThreadClusterLengths; +}; + +// Comprehensive concept that checks if an instance has all XDL forward convolution traits +// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions +template +concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && + HasElementwiseOps && HasTileParams; + +// Primary concept for checking if a type can be described +// Currently only forward convolutions are supported, but this can be extended +// in the future to include backward data and backward weight convolutions +template +concept HasConvTraits = IsXdlFwdConv>; + // Helper metafunctions to convert from ck enums to builder enums /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. @@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version() { using enum ck::BlockGemmPipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v3) - return V3; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == v5) - return V5; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } } /// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. @@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version() { using enum ck::PipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == weight_only) - return WEIGHT_ONLY; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } } /// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. @@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::BlockGemmPipelineScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Intrawave) - return INTRAWAVE; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } } /// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. @@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::LoopScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Default) - return DEFAULT; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } } /// @brief Helper structures for organizing trait data with domain-specific naming @@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction() using InstTraits = InstanceTraits; if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - { return builder::ConvDirection::FORWARD; - } else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - { return builder::ConvDirection::BACKWARD_DATA; - } else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - { return builder::ConvDirection::BACKWARD_WEIGHT; - } else - { return builder::ConvDirection::FORWARD; // Default fallback - } } /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. @@ -242,60 +288,52 @@ constexpr auto conv_spec() if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvFwdSpecialization; - if constexpr(InstTraits::kConvForwardSpecialization == Default) + switch(InstTraits::kConvForwardSpecialization) { - return builder::ConvFwdSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) - { - return builder::ConvFwdSpecialization::FILTER_3x3; + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; } } else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvBwdDataSpecialization; - if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + switch(InstTraits::kConvBwdDataSpecialization) { - return builder::ConvBwdDataSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; } } else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvBwdWeightSpecialization; - if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + switch(InstTraits::kConvBwdWeightSpecialization) { - return builder::ConvBwdWeightSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) - { - return builder::ConvBwdWeightSpecialization::ODD_C; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; } } } +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return An std::array corresponding to the tensor layouts: @@ -304,112 +342,49 @@ constexpr auto conv_spec() /// index 2 -> Output layout template constexpr auto conv_layout() + requires HasFwdConvLayouts> { - using InstTraits = InstanceTraits; - using ALayout = typename InstTraits::ALayout; - using BLayout = typename InstTraits::BLayout; - using ELayout = typename InstTraits::ELayout; + // Helper lambda to construct layout array + auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - namespace ctc = ck::tensor_layout::convolution; + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(InstTraits::kSpatialDim == 1) + switch(InstanceTraits::kSpatialDim) { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNWC, - builder::TensorLayout::GKXC, - builder::TensorLayout::GNWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NWGC, - builder::TensorLayout::GKXC, - builder::TensorLayout::NWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKXC, - builder::TensorLayout::NGKW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKCX, - builder::TensorLayout::NGKW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 2) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNHWC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::GNHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NHWGC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NGKHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKCYX, - builder::TensorLayout::NGKHW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 3) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNDHWC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::GNDHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NDHWGC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NDHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NGKDHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKCZYX, - builder::TensorLayout::NGKDHW}; - } + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } } @@ -418,39 +393,26 @@ constexpr auto conv_layout() /// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). template constexpr builder::DataType conv_data_type() + requires HasDataTypes> { using InstTraits = InstanceTraits; using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; if constexpr(std::is_same_v) - { - return builder::DataType::FP16; - } + return FP16; else if constexpr(std::is_same_v) - { - return builder::DataType::BF16; - } + return BF16; else if constexpr(std::is_same_v) - { - return builder::DataType::FP32; - } + return FP32; else if constexpr(std::is_same_v) - { - return builder::DataType::FP8; - } + return FP8; else if constexpr(std::is_same_v) - { - return builder::DataType::I8; - } + return I8; else if constexpr(std::is_same_v) - { - return builder::DataType::U8; - } + return U8; else - { - // Default fallback - return builder::DataType::FP32; - } + return FP32; // Default fallback } /// @brief Derives the elementwise operation from op type. @@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type() template constexpr builder::ElementwiseOperation elementwise_op() { + using enum builder::ElementwiseOperation; constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - { - return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - { - return builder::ElementwiseOperation::CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - { - return builder::ElementwiseOperation::SCALE; - } - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - { - return builder::ElementwiseOperation::PASS_THROUGH; - } - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - { - return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU; - } + return BIAS_BNORM_CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; } /// @brief Derives a gemm padding from a kernel instance type. @@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op() /// @return A `builder::GemmPadding` enum value corresponding to kernel padding. template constexpr builder::GemmPadding gemm_spec() + requires HasGemmSpec> { using InstTraits = InstanceTraits; using enum builder::GemmPadding; @@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec() constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - if constexpr(gemm_spec == Default) + switch(gemm_spec) { - return DEFAULT; - } - else if constexpr(gemm_spec == MPadding) - { - return M_PADDING; - } - else if constexpr(gemm_spec == NPadding) - { - return N_PADDING; - } - else if constexpr(gemm_spec == KPadding) - { - return K_PADDING; - } - else if constexpr(gemm_spec == MNPadding) - { - return MN_PADDING; - } - else if constexpr(gemm_spec == MKPadding) - { - return MK_PADDING; - } - else if constexpr(gemm_spec == NKPadding) - { - return NK_PADDING; - } - else if constexpr(gemm_spec == MNKPadding) - { - return MNK_PADDING; - } - else if constexpr(gemm_spec == OPadding) - { - return O_PADDING; - } - else if constexpr(gemm_spec == MOPadding) - { - return MO_PADDING; - } - else if constexpr(gemm_spec == NOPadding) - { - return NO_PADDING; - } - else if constexpr(gemm_spec == KOPadding) - { - return KO_PADDING; - } - else if constexpr(gemm_spec == MNOPadding) - { - return MNO_PADDING; - } - else if constexpr(gemm_spec == MKOPadding) - { - return MKO_PADDING; - } - else if constexpr(gemm_spec == NKOPadding) - { - return NKO_PADDING; - } - else if constexpr(gemm_spec == MNKOPadding) - { - return MNKO_PADDING; + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; } } @@ -571,6 +481,7 @@ struct ConvTraits; /// set of traits directly from a fully-formed device kernel `Instance` type. /// It uses `InstanceTraits` to access the kernel's template parameters. template + requires IsXdlFwdConv> struct ConvTraits { using InstTraits = InstanceTraits;