From ca5fb0a3b789f69ce625e095892e50a77ce66878 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Sun, 14 Dec 2025 21:11:46 +0000 Subject: [PATCH] Merge commit '9ac51aa0f44bae776609036f291c3cd2666e84ee' into develop --- .../factory/helpers/ck/conv_tuning_params.hpp | 1 + .../ck_tile/builder/reflect/conv_describe.hpp | 49 +++ .../builder/reflect/conv_description.hpp | 38 +-- .../ck_tile/builder/reflect/conv_traits.hpp | 312 ++++++++++-------- .../ck_tile/builder/reflect/conv_types.hpp | 109 ++++++ .../ck_tile/builder/reflect/description.hpp | 31 ++ .../builder/include/ck_tile/builder/types.hpp | 82 ++++- .../builder/test/conv/ck/test_conv_traits.cpp | 1 + .../builder/test/test_conv_description.cpp | 5 +- ...nstance_string_bwd_weight_grp_conv_xdl.cpp | 8 +- .../test_instance_string_fwd_grp_conv.cpp | 17 +- .../test_instance_string_fwd_grp_conv_dl.cpp | 8 +- ...tance_string_fwd_grp_conv_large_tensor.cpp | 13 +- .../test_instance_string_fwd_grp_conv_v3.cpp | 9 +- ...test_instance_string_fwd_grp_conv_wmma.cpp | 8 +- .../gpu/device/device_base.hpp | 11 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 6 + ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 6 + ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 16 + ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 7 + ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 6 + ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 17 + 22 files changed, 549 insertions(+), 211 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_types.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 3ec0a94960..db741f2112 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -153,6 +153,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC; default: throw "Unknown ConvFwdSpecialization"; } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp new file mode 100644 index 0000000000..fdbfa7c4e1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// @file +/// @brief Implementation of the describe() function template for convolution kernels + +#pragma once + +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck_tile/builder/reflect/conv_traits.hpp" + +namespace ck_tile::reflect { + +/// @brief Factory function to create ConvDescription from a convolution instance type +/// @tparam Instance The convolution instance type (must have ConvTraits) +/// @return A ConvDescription object populated with the instance's configuration details +template +conv::ConvDescription describe() +{ + using Traits = conv::ConvTraits; + + return conv::ConvDescription( + conv::ConvSignatureInfo{ + .spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .input_layout = Traits::layout[0], + .weight_layout = Traits::layout[1], + .output_layout = Traits::layout[2], + .data_type = Traits::data_type, + .input_element_op = Traits::input_element_op, + .weight_element_op = Traits::weight_element_op, + .output_element_op = Traits::output_element_op, + }, + conv::GemmAlgorithmInfo{ + .thread_block_size = Traits::thread_block_size, + .tile_dims = Traits::tile_dims, + .warp_gemm = Traits::warp_gemm, + .a_tile_transfer = Traits::a_tile_transfer, + .b_tile_transfer = Traits::b_tile_transfer, + .c_tile_transfer = Traits::c_tile_transfer, + .pipeline_version = Traits::pipeline_version, + .pipeline_scheduler = Traits::pipeline_scheduler, + .conv_specialization = Traits::conv_specialization, + .padding = Traits::gemm_padding, + }, + []() { return reflect::instance_string(); }); +} + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 59ff83c238..46c9bb488e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include @@ -249,41 +249,7 @@ class ConvDescription : public Description GemmAlgorithmInfo algorithm_; std::function instance_string_getter_; }; + } // namespace conv -/// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have ConvTraits specialization) -/// @return A ConvDescription object populated with the instance's configuration details -template -conv::ConvDescription describe() -{ - using Traits = conv::ConvTraits; - - return conv::ConvDescription( - conv::ConvSignatureInfo{ - .spatial_dim = Traits::spatial_dim, - .direction = Traits::direction, - .input_layout = Traits::layout[0], - .weight_layout = Traits::layout[1], - .output_layout = Traits::layout[2], - .data_type = Traits::data_type, - .input_element_op = Traits::input_element_op, - .weight_element_op = Traits::weight_element_op, - .output_element_op = Traits::output_element_op, - }, - conv::GemmAlgorithmInfo{ - .thread_block_size = Traits::thread_block_size, - .tile_dims = Traits::tile_dims, - .warp_gemm = Traits::warp_gemm, - .a_tile_transfer = Traits::a_tile_transfer, - .b_tile_transfer = Traits::b_tile_transfer, - .c_tile_transfer = Traits::c_tile_transfer, - .pipeline_version = Traits::pipeline_version, - .pipeline_scheduler = Traits::pipeline_scheduler, - .conv_specialization = Traits::conv_specialization, - .padding = Traits::gemm_padding, - }, - []() { return reflect::instance_string(); }); -} - } // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index e5a5638887..ab1d1d76ed 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -10,8 +10,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/utility/pipeline_enum.hpp" #include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/conv_builder.hpp" #include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/conv_types.hpp" #include "ck_tile/builder/reflect/instance_traits.hpp" #include "ck_tile/builder/reflect/instance_traits_util.hpp" #include "ck_tile/builder/types.hpp" @@ -161,103 +161,19 @@ constexpr auto convert_pipeline_scheduler() } } -/// @brief Helper structures for organizing trait data with domain-specific naming - -/// @brief Data tile dimensions processed by a workgroup. -/// @details This struct defines the M, N, and K dimensions of the data tile -/// that a single workgroup (thread block) is responsible for processing in the -/// underlying GEMM computation. -struct DataTileInfo -{ - int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). - int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). - int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). -}; - -/// @brief Dimensions for an input data tile transfer. -/// @details Defines the shape of the input tile (A or B matrix) as it is -/// transferred from global memory to LDS. The tile is conceptually divided -/// into k0 and k1 dimensions. -struct InputTileTransferDimensions -{ - int k0; ///< The outer dimension of K, where K = k0 * k1. - int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. - int k1; ///< The inner dimension of K, often corresponding to the vector load size from global - ///< memory. -}; - -/// @brief Parameters governing the transfer of an input tile. -/// @details This struct holds configuration details for how an input tile is -/// loaded from global memory into LDS, including thread clustering, memory -/// access patterns, and vectorization settings. -struct InputTileTransferParams -{ - int k1; ///< The inner K dimension size, often matching the vectorization width. - std::array - thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how - ///< many threads are arranged on each axis. - std::array thread_cluster_order; ///< The order of thread spatial distribution over the - ///< input tensor dimensions. - std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which - ///< dimension to read first). - int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed - ///< (the contiguous dimension). - int src_scalar_per_vector; ///< The size of the vector access instruction; the number of - ///< elements accessed per thread per instruction. - int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 - ///< dimension. - bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank - ///< conflicts. -}; - -/// @brief Complete information for an input tile transfer. -/// @details Combines the dimensional information and transfer parameters for -/// a full description of an input tile's journey from global memory to LDS. -struct InputTileTransferInfo -{ - InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. - InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. -}; - -/// @brief Parameters for the warp-level GEMM computation. -/// @details Defines the configuration of the GEMM operation performed by each -/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. -struct WarpGemmParams -{ - int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). - int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). - int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per - ///< wavefront (MXdlPerWave). - int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per - ///< wavefront (NXdlPerWave). -}; - -/// @brief Parameters for shuffling data between warps (CShuffle optimization). -/// @details Configures how many MFMA instruction results are processed per -/// wave in each iteration of the CShuffle routine. -struct WarpShuffleParams -{ - int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave - ///< per shuffle iteration. - int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave - ///< per shuffle iteration. -}; - -/// @brief Information for the output tile transfer (CShuffle). -/// @details Describes how the final computed tile (C matrix) is written out from -/// LDS to global memory, including shuffling, thread clustering, and vectorization. -struct OutputTileTransferInfo -{ - WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. - // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl - std::array thread_cluster_dims; ///< The spatial thread distribution used for storing - ///< data into the output tensor. - int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the - ///< output tensor. -}; - // Helper metafunctions to derive signature information from Instance types +/// @brief Helper function to report unsupported convolution direction with a clear error message. +template +consteval void report_unsupported_conv_direction_error() +{ + throw "Unsupported convolution direction detected!\n" + "The kernel instance does not have a recognized convolution specialization.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + /// @brief Derives the convolution direction from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). @@ -273,7 +189,10 @@ constexpr builder::ConvDirection conv_direction() else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) return builder::ConvDirection::BACKWARD_WEIGHT; else - return builder::ConvDirection::FORWARD; // Default fallback + { + report_unsupported_conv_direction_error(); + return builder::ConvDirection::FORWARD; // Unreachable + } } /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. @@ -296,6 +215,7 @@ constexpr auto conv_spec() case Filter1x1Pad0: return FILTER_1X1_PAD0; case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; case Filter3x3: return FILTER_3x3; + case OddC: return ODD_C; } } else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) @@ -334,6 +254,20 @@ template && std::is_same_v && std::is_same_v; +/// @brief Helper function to report unsupported layout combinations with a clear error message. +/// @details This consteval function is designed to fail at compile time with a descriptive +/// error message when an unsupported layout combination is encountered. +template +consteval void report_unsupported_layout_error() +{ + // This will produce a compile-time error with the exception message + throw "Unsupported convolution layout combination detected!\n" + "The combination of ALayout, BLayout, and ELayout template parameters\n" + "is not recognized for the given spatial dimension.\n" + "Please verify that your convolution instance uses a supported layout configuration.\n" + "Check the conv_layout() function for the list of supported layout combinations."; +} + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return An std::array corresponding to the tensor layouts: @@ -358,6 +292,8 @@ constexpr auto conv_layout() case 1: if constexpr(layouts_are) return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); if constexpr(layouts_are) return layouts(NWGC, GKXC, NWGK); if constexpr(layouts_are) @@ -368,8 +304,12 @@ constexpr auto conv_layout() case 2: if constexpr(layouts_are) return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); if constexpr(layouts_are) return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); if constexpr(layouts_are) return layouts(NGCHW, GKYXC, NGKHW); if constexpr(layouts_are) @@ -378,6 +318,8 @@ constexpr auto conv_layout() case 3: if constexpr(layouts_are) return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); if constexpr(layouts_are) return layouts(NDHWGC, GKZYXC, NDHWGK); if constexpr(layouts_are) @@ -386,11 +328,31 @@ constexpr auto conv_layout() return layouts(NGCDHW, GKCZYX, NGKDHW); break; } + + // If we reach here, the layout combination is not supported + // Call consteval function to trigger a compile-time error with a clear message + report_unsupported_layout_error::kSpatialDim>(); + + // This return is unreachable but needed to satisfy the compiler + return layouts(GNHWC, GKYXC, GNHWK); +} + +/// @brief Helper function to report unsupported data type with a clear error message. +template +consteval void report_unsupported_data_type_error() +{ + throw "Unsupported data type detected!\n" + "The ADataType is not recognized.\n" + "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " + "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " + "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " + "(BF8), " + "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" + "Please verify that your kernel instance uses a supported data type."; } /// @brief Derives the data type from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). +/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). template constexpr builder::DataType conv_data_type() requires HasDataTypes> @@ -401,18 +363,50 @@ constexpr builder::DataType conv_data_type() if constexpr(std::is_same_v) return FP16; + else if constexpr(std::is_same_v>) + return FP16_FP16; else if constexpr(std::is_same_v) return BF16; + else if constexpr(std::is_same_v>) + return BF16_BF16; else if constexpr(std::is_same_v) return FP32; + else if constexpr(std::is_same_v>) + return FP32_FP32; + else if constexpr(std::is_same_v) + return FP64; else if constexpr(std::is_same_v) return FP8; + else if constexpr(std::is_same_v) + return BF8; + else if constexpr(std::is_same_v) + return BF8; else if constexpr(std::is_same_v) return I8; + else if constexpr(std::is_same_v>) + return I8_I8; else if constexpr(std::is_same_v) return U8; else - return FP32; // Default fallback + { + report_unsupported_data_type_error(); + return FP32; // Unreachable + } +} + +/// @brief Helper function to report unsupported elementwise operation with a clear error message. +template +consteval void report_unsupported_elementwise_op_error() +{ + throw "Unsupported elementwise operation detected!\n" + "The elementwise operation type is not recognized.\n" + "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " + "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " + "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " + "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " + "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " + "UnaryConvert.\n" + "Please verify that your kernel instance uses a supported elementwise operation."; } /// @brief Derives the elementwise operation from op type. @@ -424,16 +418,83 @@ constexpr builder::ElementwiseOperation elementwise_op() using enum builder::ElementwiseOperation; constexpr std::string_view name = detail::elementwise_op_name(); - if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + if constexpr(detail::case_insensitive_equal(name, "AddClamp")) + return ADD_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) + return ADD_RELU_ADD; + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) return BIAS_BNORM_CLAMP; - if constexpr(detail::case_insensitive_equal(name, "Clamp")) + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + return BILINEAR; + else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) + return BIAS_BNORM_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) return CLAMP; - if constexpr(detail::case_insensitive_equal(name, "Scale")) + else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) + return CONV_INVSCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) + return CONV_SCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) + return CONV_SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) + return CONV_SCALE_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Scale")) return SCALE; - if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) + return SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) return PASS_THROUGH; - if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) return SCALEADD_SCALEADD_RELU; + else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) + return DYNAMIC_UNARY_OP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) + return UNARY_COMBINED_OP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) + return ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) + return ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) + return ADD_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) + return ADD_ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) + return ADD_MUL_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) + return ADD_MUL2_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) + return UNARY_CONVERT; + else if constexpr(detail::case_insensitive_equal(name, "Logistic")) + return LOGISTIC; + else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) + return CLIPPED_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Swish")) + return SWISH; + else if constexpr(detail::case_insensitive_equal(name, "Elu")) + return ELU; + else if constexpr(detail::case_insensitive_equal(name, "Power")) + return POWER; + else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) + return LEAKY_RELU; + else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) + return UNARY_ABS; + else if constexpr(detail::case_insensitive_equal(name, "Relu")) + return RELU; + else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) + return SOFT_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) + return SIGMOID; + else if constexpr(detail::case_insensitive_equal(name, "TanH")) + return TANH; + else if constexpr(detail::case_insensitive_equal(name, "Gelu")) + return GELU; + else if constexpr(detail::case_insensitive_equal(name, "Silu")) + return SILU; + else + { + report_unsupported_elementwise_op_error(); + return PASS_THROUGH; // Unreachable + } } /// @brief Derives a gemm padding from a kernel instance type. @@ -606,45 +667,4 @@ struct ConvTraits static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); }; -/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. -/// @details This specialization provides backward compatibility for reflecting -/// on kernels defined via the `ConvBuilder` interface. It works by first -/// creating the `Instance` via the builder, and then delegating -/// all trait extraction to the `ConvTraits` specialization. -template -struct ConvTraits> -{ - using Instance = typename builder::ConvBuilder::Instance; - - // Delegate to Instance-based ConvTraits - using InstanceConvTraits = ConvTraits; - - // Forward all members from Instance-based traits - static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; - static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; - static constexpr auto layout = InstanceConvTraits::layout; - static constexpr builder::DataType data_type = InstanceConvTraits::data_type; - - static constexpr builder::ElementwiseOperation input_element_op = - InstanceConvTraits::input_element_op; - static constexpr builder::ElementwiseOperation weight_element_op = - InstanceConvTraits::weight_element_op; - static constexpr builder::ElementwiseOperation output_element_op = - InstanceConvTraits::output_element_op; - - static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; - static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; - - static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; - static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; - static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; - static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; - static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; - static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; - static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; - static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; -}; - } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_types.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_types.hpp new file mode 100644 index 0000000000..bb98455617 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_types.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// @file +/// @brief Type definitions for convolution reflection +/// +/// This file contains the type definitions used by both conv_traits.hpp and conv_description.hpp +/// to avoid circular dependencies. + +#pragma once + +#include + +namespace ck_tile::reflect::conv { + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. +struct InputTileTransferDimensions +{ + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. +}; + +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. +struct InputTileTransferParams +{ + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. + std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. +}; + +/// @brief Complete information for an input tile transfer. +/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. +}; + +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. +struct WarpGemmParams +{ + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). +}; + +/// @brief Parameters for shuffling data between warps (CShuffle optimization). +/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. +}; + +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. +}; + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/description.hpp b/experimental/builder/include/ck_tile/builder/reflect/description.hpp index c3a38964a7..6a7b2513be 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/description.hpp @@ -20,6 +20,11 @@ namespace ck_tile::reflect { class Description { public: + Description() = default; + Description(const Description&) = default; + Description(Description&&) = default; + Description& operator=(const Description&) = default; + Description& operator=(Description&&) = default; /// @brief Virtual destructor for proper cleanup of derived classes virtual ~Description() = default; @@ -36,4 +41,30 @@ class Description virtual std::string instance_string() const = 0; }; +/// @brief A specialized Description that only supports instance_string() +/// This is a helper class for kernels that don't yet have full ConvDescription support. +/// The brief() and detailed() methods return "not supported" placeholders. +class InstanceStringDescription : public Description +{ + public: + /// @brief Construct with an instance string + /// @param instance The instance string to store + explicit InstanceStringDescription(std::string instance) : instance_(std::move(instance)) {} + + /// @brief Returns "not supported" as brief descriptions are not implemented + /// @return A placeholder string indicating the feature is not supported + std::string brief() const override { return "not supported"; } + + /// @brief Returns "not supported" as detailed descriptions are not implemented + /// @return A placeholder string indicating the feature is not supported + std::string detailed() const override { return "not supported"; } + + /// @brief Returns the stored instance string + /// @return The instance string provided during construction + std::string instance_string() const override { return instance_; } + + private: + std::string instance_; ///< The stored instance string +}; + } // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index fb732df55e..f7386720b3 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -11,15 +11,22 @@ namespace ck_tile::builder { +// TODO: Handle tuple types and FP8/BF8 properly enum class DataType { UNDEFINED_DATA_TYPE = 0, FP32, + FP32_FP32, FP16, + FP16_FP16, BF16, + BF16_BF16, FP8, + BF8, + FP64, INT32, I8, + I8_I8, U8 }; @@ -102,13 +109,44 @@ enum class ConvDirection }; // Fused element-wise operations. +// TODO: Generalize design rather than enumerating all possible ops. enum class ElementwiseOperation { + ADD_CLAMP, + ADD_RELU_ADD, + ACTIVATION_MUL2_CLAMP, + ACTIVATION_MUL_CLAMP, + ADD_ACTIVATION_MUL_CLAMP, + ADD_ACTIVATION_MUL2_CLAMP, + ADD_MUL_ACTIVATION_MUL_CLAMP, + ADD_MUL2_ACTIVATION_MUL_CLAMP, BIAS_BNORM_CLAMP, + BILINEAR, SCALE, + SCALE_ADD, CLAMP, + CONV_INVSCALE, + CONV_SCALE, + CONV_SCALE_ADD, + CONV_SCALE_RELU, PASS_THROUGH, - SCALEADD_SCALEADD_RELU + SCALEADD_SCALEADD_RELU, + DYNAMIC_UNARY_OP, + UNARY_COMBINED_OP, + UNARY_CONVERT, + LOGISTIC, + CLIPPED_RELU, + SWISH, + ELU, + POWER, + LEAKY_RELU, + UNARY_ABS, + RELU, + SOFT_RELU, + SIGMOID, + TANH, + GELU, + SILU }; // Enums for pipeline versions & schedulers @@ -160,7 +198,8 @@ enum class ConvFwdSpecialization DEFAULT, FILTER_1X1_PAD0, FILTER_1X1_STRIDE1_PAD0, - FILTER_3x3 + FILTER_3x3, + ODD_C }; // Enums for the backward data convolution specialization. @@ -219,11 +258,17 @@ inline std::string_view toString(DataType dt) switch(dt) { case FP16: return "FP16"; + case FP16_FP16: return "FP16_FP16"; case FP32: return "FP32"; + case FP32_FP32: return "FP32_FP32"; case BF16: return "BF16"; + case BF16_BF16: return "BF16_BF16"; case FP8: return "FP8"; + case BF8: return "BF8"; + case FP64: return "FP64"; case INT32: return "INT32"; case I8: return "I8"; + case I8_I8: return "I8_I8"; case U8: return "U8"; case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE"; default: return "Unknown"; @@ -247,11 +292,41 @@ inline std::string_view toString(ElementwiseOperation op) using enum ElementwiseOperation; switch(op) { + case ADD_CLAMP: return "ADD_CLAMP"; + case ADD_RELU_ADD: return "ADD_RELU_ADD"; + case ACTIVATION_MUL2_CLAMP: return "ACTIVATION_MUL2_CLAMP"; + case ACTIVATION_MUL_CLAMP: return "ACTIVATION_MUL_CLAMP"; + case ADD_ACTIVATION_MUL_CLAMP: return "ADD_ACTIVATION_MUL_CLAMP"; + case ADD_ACTIVATION_MUL2_CLAMP: return "ADD_ACTIVATION_MUL2_CLAMP"; + case ADD_MUL_ACTIVATION_MUL_CLAMP: return "ADD_MUL_ACTIVATION_MUL_CLAMP"; + case ADD_MUL2_ACTIVATION_MUL_CLAMP: return "ADD_MUL2_ACTIVATION_MUL_CLAMP"; + case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP"; + case BILINEAR: return "BILINEAR"; case CLAMP: return "CLAMP"; case SCALE: return "SCALE"; + case SCALE_ADD: return "SCALE_ADD"; + case CONV_INVSCALE: return "CONV_INVSCALE"; + case CONV_SCALE: return "CONV_SCALE"; + case CONV_SCALE_ADD: return "CONV_SCALE_ADD"; + case CONV_SCALE_RELU: return "CONV_SCALE_RELU"; case PASS_THROUGH: return "PASS_THROUGH"; - case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP"; case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU"; + case DYNAMIC_UNARY_OP: return "DYNAMIC_UNARY_OP"; + case UNARY_COMBINED_OP: return "UNARY_COMBINED_OP"; + case UNARY_CONVERT: return "UNARY_CONVERT"; + case LOGISTIC: return "LOGISTIC"; + case CLIPPED_RELU: return "CLIPPED_RELU"; + case SWISH: return "SWISH"; + case ELU: return "ELU"; + case POWER: return "POWER"; + case LEAKY_RELU: return "LEAKY_RELU"; + case UNARY_ABS: return "UNARY_ABS"; + case RELU: return "RELU"; + case SOFT_RELU: return "SOFT_RELU"; + case SIGMOID: return "SIGMOID"; + case TANH: return "TANH"; + case GELU: return "GELU"; + case SILU: return "SILU"; default: return "Unknown"; } } @@ -305,6 +380,7 @@ inline std::string_view toString(ConvFwdSpecialization spec) case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0"; case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; case FILTER_3x3: return "FILTER_3x3"; + case ODD_C: return "ODD_C"; default: return "Unknown"; } } diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index d052aba548..d5661ad67b 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index ace9ce0239..158cb2668f 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -4,8 +4,9 @@ #include #include -#include -#include +#include "ck_tile/builder/conv_builder.hpp" +#include "ck_tile/builder/reflect/conv_description.hpp" +#include "ck_tile/builder/reflect/conv_describe.hpp" #include "testing_utils.hpp" #include "impl/conv_signature_types.hpp" #include "impl/conv_algorithm_types.hpp" diff --git a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl.cpp b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl.cpp index 88a57a3735..38e79a2eb5 100644 --- a/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl.cpp +++ b/experimental/builder/test/test_instance_string_bwd_weight_grp_conv_xdl.cpp @@ -72,14 +72,16 @@ std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle" ",1" // MaxTransposeTransferSrcScalarPerVector ",1>"; // MaxTransposeTransferDstScalarPerVector -// Test GetInstanceString through base class pointer for backward weight XDL variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForBwdWeightGrpConvXdl) +// Test describe() through base class pointer for backward weight XDL variant +TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvXdl) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } // TODO: Add DescriptionReturnsCorrectValueForBwdWeightGrpConvXdl test once ckr::describe supports diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp index 35f3db1469..bb67e18087 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: MIT #include "gtest/gtest.h" -#include "ck_tile/builder/reflect/instance_traits.hpp" -#include "ck_tile/builder/reflect/conv_description.hpp" -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include +#include +#include +#include +#include namespace { @@ -77,14 +78,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" ",Default" // LoopScheduler ",1>"; // NumGroupsToMerge -// Test GetInstanceString through base class pointer for standard XDL variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConv) +// Test describe() through base class pointer for standard XDL variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConv) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConv) diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_dl.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_dl.cpp index 4f018cca11..cc585342c6 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_dl.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_dl.cpp @@ -71,14 +71,16 @@ std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK" ",5" // CThreadTransferSrcDstVectorDim ",1>"; // CThreadTransferDstScalarPerVector -// Test GetInstanceString through base class pointer for DL variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvDl) +// Test describe() through base class pointer for DL variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvDl) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } // TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvDl test once ckr::describe supports DL diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp index 26b50bea6d..2f4e15e35c 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_large_tensor.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: MIT #include -#include +#include #include -#include +#include #include +#include namespace { @@ -76,14 +77,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten ",fp16" // BComputeDataType ",Default>"; // LoopScheduler -// Test GetInstanceString through base class pointer for large tensor variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvLargeTensor) +// Test describe() through base class pointer for large tensor variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvLargeTensor) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvLargeTensor) diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp index 604667dd10..ccfa4c7197 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -78,14 +79,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" ",fp16" // BComputeDataType ",false>"; // DirectLoad -// Test GetInstanceString through base class pointer for V3 variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvV3) +// Test describe() through base class pointer for V3 variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvV3) diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma.cpp index 717b770c52..1b82a37a1c 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma.cpp @@ -76,14 +76,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle" ",Default" // LoopSched ",v1>"; // PipelineVer -// Test GetInstanceString through base class pointer for WMMA variant -TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvWmma) +// Test describe() through base class pointer for WMMA variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmma) { using BaseClass = ck::tensor_operation::device::BaseOperator; DeviceInstance device_instance; BaseClass* base_ptr = &device_instance; - EXPECT_EQ(base_ptr->GetInstanceString(), expected_str); + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); } // TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvWmma test once ckr::describe supports WMMA diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index ec623db6f7..3e37aac86e 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -8,8 +8,13 @@ #include #include #include +#include #include "ck/stream_config.hpp" + +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#endif #endif #include "ck/utility/get_id.hpp" @@ -227,6 +232,12 @@ struct BaseOperator #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } + +#ifdef CK_EXPERIMENTAL_BUILDER + // Return a description object for this operator, or nullptr if not supported. + virtual std::unique_ptr describe() const { return nullptr; } +#endif + virtual std::string GetInstanceString() const { return ""; } virtual std::string GetTypeIdName() const { return typeid(*this).name(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index be5c6eba40..42ad21dafe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -25,6 +25,7 @@ #include "ck/host_utility/kernel_launch.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #endif @@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } #endif size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 347ea25e62..b5ca71d1fa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -25,6 +25,7 @@ #include "ck/host_utility/io.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #endif @@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } #endif }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index a9b0975050..5ed8da8d1b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -29,6 +29,7 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/conv_describe.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #endif @@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + static_assert(ck_tile::reflect::conv::HasConvTraits, + "ConvTraits specialization not found for this device operation. " + "If you modified the template parameters of this class, ensure that " + "the corresponding ConvTraits specialization in " + "ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that " + "InstanceTraits in " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp " + "provides all required members for ConvTraits to work."); + return std::make_unique( + ck_tile::reflect::describe()); + } #endif size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 380f94426f..e69a9caa9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -29,6 +29,7 @@ #include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/io.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/conv_describe.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #endif @@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + return std::make_unique( + ck_tile::reflect::describe()); + } #endif size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index a3391c55e8..32e444fe1f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -25,6 +25,7 @@ #include "ck/host_utility/io.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #endif @@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } #endif }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index b51b78d6b9..b21af2abb0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -25,6 +25,7 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" #ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/conv_describe.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" #endif @@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor "for the given template parameters."); return ck_tile::reflect::instance_string(); } + + std::unique_ptr describe() const override + { + static_assert( + ck_tile::reflect::conv::HasConvTraits, + "ConvTraits specialization not found for this device operation. " + "If you modified the template parameters of this class, ensure that " + "the corresponding ConvTraits specialization in " + "ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that " + "InstanceTraits in " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp " + "provides all required members for ConvTraits to work."); + return std::make_unique( + ck_tile::reflect::describe()); + } #endif };