Add describe() method to device ops for runtime introspection (#3375)

Introduces a polymorphic describe() method to BaseOperator that enables runtime introspection of kernel configurations through a unified interface.

Key changes:

* Add virtual describe() method to BaseOperator returning Description objects
* Implement describe() in 6 device operation classes (conv fwd/bwd variants)
* Create conv_describe.hpp with factory function for ConvDescription
* Extract type definitions to conv_types.hpp to resolve circular dependencies
* Add InstanceStringDescription for kernels without full ConvDescription support

Other improvements:

* Update tests to use describe() instead of GetInstanceString()
* Remove circular dependency include from conv_traits.hpp
* Add ODD_C to ConvFwdSpecialization enum and fix OddC mapping
* Replace silent fallbacks in conv_layout(), conv_direction(), conv_data_type(), and elementwise_op() with compile-time errors

This provides a foundation for runtime kernel introspection and better tooling support for analyzing and debugging kernel configurations.
Author: John Shumway
Date: 2025-12-14 12:49:12 -08:00
Committed by: GitHub
Commit: 9ac51aa0f4 (parent: 21f06aa47d)
22 changed files with 549 additions and 211 deletions
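
A minimal sketch of consuming describe() through the BaseOperator interface, assuming a build with CK_EXPERIMENTAL_BUILDER defined and some concrete device operation (mirroring the updated tests below):

// Sketch only: `op` can be any device operation; describe() returns nullptr
// for operators that do not implement introspection yet.
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"

void inspect(const ck::tensor_operation::device::BaseOperator& op)
{
    auto desc = op.describe();
    if(!desc)
    {
        std::cout << "no description available\n";
        return;
    }
    // Every Description provides instance_string(); ConvDescription also
    // renders brief()/detailed() summaries, while InstanceStringDescription
    // returns "not supported" for those two.
    std::cout << desc->instance_string() << '\n' << desc->brief() << '\n';
}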

File: ck_tile builder (SetFwdConvSpecialization mapping)

@@ -153,6 +153,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
default: throw "Unknown ConvFwdSpecialization";
}
}

File: ck_tile/builder/reflect/conv_describe.hpp (new file)

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// @file
/// @brief Implementation of the describe() function template for convolution kernels
#pragma once
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck_tile/builder/reflect/conv_traits.hpp"
namespace ck_tile::reflect {
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have ConvTraits)
/// @return A ConvDescription object populated with the instance's configuration details
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.input_layout = Traits::layout[0],
.weight_layout = Traits::layout[1],
.output_layout = Traits::layout[2],
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding,
},
[]() { return reflect::instance_string<Instance>(); });
}
} // namespace ck_tile::reflect
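
For kernel types that satisfy HasConvTraits, the factory can also be invoked directly. A short sketch, where `Instance` stands in for a concrete device op type (e.g. an instantiation of DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle):

// Sketch only: print the full configuration of a reflectable conv instance.
#include <iostream>
#include "ck_tile/builder/reflect/conv_describe.hpp"

template <ck_tile::reflect::conv::HasConvTraits Instance>
void print_conv_config()
{
    // ConvSignatureInfo and GemmAlgorithmInfo are populated from
    // ConvTraits<Instance>; the instance string is produced lazily by the
    // lambda passed to the ConvDescription constructor.
    auto desc = ck_tile::reflect::describe<Instance>();
    std::cout << desc.detailed() << '\n';
}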

File: ck_tile/builder/reflect/conv_description.hpp

@@ -25,7 +25,7 @@
#include <functional>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/conv_types.hpp>
#include <ck_tile/builder/reflect/description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
@@ -249,41 +249,7 @@ class ConvDescription : public Description
GemmAlgorithmInfo algorithm_;
std::function<std::string()> instance_string_getter_;
};
} // namespace conv
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
/// @return A ConvDescription object populated with the instance's configuration details
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.input_layout = Traits::layout[0],
.weight_layout = Traits::layout[1],
.output_layout = Traits::layout[2],
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding,
},
[]() { return reflect::instance_string<Instance>(); });
}
} // namespace ck_tile::reflect

File: ck_tile/builder/reflect/conv_traits.hpp

@@ -10,8 +10,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/reflect/conv_types.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
#include "ck_tile/builder/types.hpp"
@@ -161,103 +161,19 @@ constexpr auto convert_pipeline_scheduler()
}
}
/// @brief Helper structures for organizing trait data with domain-specific naming
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types
/// @brief Helper function to report unsupported convolution direction with a clear error message.
template <typename Instance>
consteval void report_unsupported_conv_direction_error()
{
throw "Unsupported convolution direction detected!\n"
"The kernel instance does not have a recognized convolution specialization.\n"
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
"kConvBwdWeightSpecialization.\n"
"Please verify that your kernel instance is properly configured.";
}
/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
@@ -273,7 +189,10 @@ constexpr builder::ConvDirection conv_direction()
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
return builder::ConvDirection::BACKWARD_WEIGHT;
else
return builder::ConvDirection::FORWARD; // Default fallback
{
report_unsupported_conv_direction_error<Instance>();
return builder::ConvDirection::FORWARD; // Unreachable
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
@@ -296,6 +215,7 @@ constexpr auto conv_spec()
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter3x3: return FILTER_3x3;
case OddC: return ODD_C;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
@@ -334,6 +254,20 @@ template <typename A,
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function is designed to fail at compile time with a descriptive
/// error message when an unsupported layout combination is encountered.
template <typename A, typename B, typename E, int SpatialDim>
consteval void report_unsupported_layout_error()
{
// This will produce a compile-time error with the exception message
throw "Unsupported convolution layout combination detected!\n"
"The combination of ALayout, BLayout, and ELayout template parameters\n"
"is not recognized for the given spatial dimension.\n"
"Please verify that your convolution instance uses a supported layout configuration.\n"
"Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
@@ -358,6 +292,8 @@ constexpr auto conv_layout()
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
@@ -368,8 +304,12 @@ constexpr auto conv_layout()
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
@@ -378,6 +318,8 @@ constexpr auto conv_layout()
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
@@ -386,11 +328,31 @@ constexpr auto conv_layout()
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
// If we reach here, the layout combination is not supported
// Call consteval function to trigger a compile-time error with a clear message
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
// This return is unreachable but needed to satisfy the compiler
return layouts(GNHWC, GKYXC, GNHWK);
}
/// @brief Helper function to report unsupported data type with a clear error message.
template <typename ADataType>
consteval void report_unsupported_data_type_error()
{
throw "Unsupported data type detected!\n"
"The ADataType is not recognized.\n"
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
"(BF8), "
"int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
"Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
template <typename Instance>
constexpr builder::DataType conv_data_type()
requires HasDataTypes<InstanceTraits<Instance>>
@@ -401,18 +363,50 @@ constexpr builder::DataType conv_data_type()
if constexpr(std::is_same_v<ADataType, ck::half_t>)
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
return FP16_FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
return BF16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
return BF16_BF16;
else if constexpr(std::is_same_v<ADataType, float>)
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
return FP32_FP32;
else if constexpr(std::is_same_v<ADataType, double>)
return FP64;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
return FP8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
return I8;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
return I8_I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
return U8;
else
return FP32; // Default fallback
{
report_unsupported_data_type_error<ADataType>();
return FP32; // Unreachable
}
}
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
template <typename ElementwiseOp>
consteval void report_unsupported_elementwise_op_error()
{
throw "Unsupported elementwise operation detected!\n"
"The elementwise operation type is not recognized.\n"
"Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
"BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
"ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
"UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
"Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
"UnaryConvert.\n"
"Please verify that your kernel instance uses a supported elementwise operation.";
}
/// @brief Derives the elementwise operation from op type.
@@ -424,16 +418,83 @@ constexpr builder::ElementwiseOperation elementwise_op()
using enum builder::ElementwiseOperation;
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
return ADD_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
return ADD_RELU_ADD;
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
return BIAS_BNORM_CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
return BILINEAR;
else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
return BIAS_BNORM_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
return CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Scale"))
else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
return CONV_INVSCALE;
else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
return CONV_SCALE;
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
return CONV_SCALE_ADD;
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
return CONV_SCALE_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
return SCALE;
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
return SCALE_ADD;
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
return PASS_THROUGH;
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
return SCALEADD_SCALEADD_RELU;
else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
return DYNAMIC_UNARY_OP;
else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
return UNARY_COMBINED_OP;
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
return ACTIVATION_MUL2_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
return ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
return ADD_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
return ADD_ACTIVATION_MUL2_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
return ADD_MUL_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
return ADD_MUL2_ACTIVATION_MUL_CLAMP;
else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
return UNARY_CONVERT;
else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
return LOGISTIC;
else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
return CLIPPED_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Swish"))
return SWISH;
else if constexpr(detail::case_insensitive_equal(name, "Elu"))
return ELU;
else if constexpr(detail::case_insensitive_equal(name, "Power"))
return POWER;
else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
return LEAKY_RELU;
else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
return UNARY_ABS;
else if constexpr(detail::case_insensitive_equal(name, "Relu"))
return RELU;
else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
return SOFT_RELU;
else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
return SIGMOID;
else if constexpr(detail::case_insensitive_equal(name, "TanH"))
return TANH;
else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
return GELU;
else if constexpr(detail::case_insensitive_equal(name, "Silu"))
return SILU;
else
{
report_unsupported_elementwise_op_error<ElementwiseOp>();
return PASS_THROUGH; // Unreachable
}
}
/// @brief Derives a gemm padding from a kernel instance type.
@@ -606,45 +667,4 @@ struct ConvTraits<Instance>
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;
// Forward all members from Instance-based traits
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
static constexpr auto layout = InstanceConvTraits::layout;
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
static constexpr builder::ElementwiseOperation input_element_op =
InstanceConvTraits::input_element_op;
static constexpr builder::ElementwiseOperation weight_element_op =
InstanceConvTraits::weight_element_op;
static constexpr builder::ElementwiseOperation output_element_op =
InstanceConvTraits::output_element_op;
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv
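
All of the report_unsupported_*_error helpers above rely on the same mechanism: a throw expression is never a constant expression, so reaching one inside a consteval function aborts compilation, and the message string appears in the compiler diagnostic. A minimal standalone illustration of the pattern, with a made-up validation rule:

// Sketch only: the consteval-throw pattern in isolation.
consteval int checked_vector_width(int w)
{
    if(w == 1 || w == 2 || w == 4 || w == 8)
        return w;
    // Never a constant expression: evaluating this line fails the enclosing
    // constant evaluation and surfaces the string in the error output.
    throw "Unsupported vector width: expected 1, 2, 4, or 8.";
}

constexpr int ok = checked_vector_width(4); // compiles
// constexpr int bad = checked_vector_width(3); // compile-time error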

File: ck_tile/builder/reflect/conv_types.hpp (new file)

@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// @file
/// @brief Type definitions for convolution reflection
///
/// This file contains the type definitions used by both conv_traits.hpp and conv_description.hpp
/// to avoid circular dependencies.
#pragma once
#include <array>
namespace ck_tile::reflect::conv {
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
} // namespace ck_tile::reflect::conv

File: ck_tile/builder/reflect/description.hpp

@@ -20,6 +20,11 @@ namespace ck_tile::reflect {
class Description
{
public:
Description() = default;
Description(const Description&) = default;
Description(Description&&) = default;
Description& operator=(const Description&) = default;
Description& operator=(Description&&) = default;
/// @brief Virtual destructor for proper cleanup of derived classes
virtual ~Description() = default;
@@ -36,4 +41,30 @@ class Description
virtual std::string instance_string() const = 0;
};
/// @brief A specialized Description that only supports instance_string()
/// This is a helper class for kernels that don't yet have full ConvDescription support.
/// The brief() and detailed() methods return "not supported" placeholders.
class InstanceStringDescription : public Description
{
public:
/// @brief Construct with an instance string
/// @param instance The instance string to store
explicit InstanceStringDescription(std::string instance) : instance_(std::move(instance)) {}
/// @brief Returns "not supported" as brief descriptions are not implemented
/// @return A placeholder string indicating the feature is not supported
std::string brief() const override { return "not supported"; }
/// @brief Returns "not supported" as detailed descriptions are not implemented
/// @return A placeholder string indicating the feature is not supported
std::string detailed() const override { return "not supported"; }
/// @brief Returns the stored instance string
/// @return The instance string provided during construction
std::string instance_string() const override { return instance_; }
private:
std::string instance_; ///< The stored instance string
};
} // namespace ck_tile::reflect
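
A short sketch of the fallback path, using only what this header defines; the instance string below is a made-up placeholder:

// Sketch only: kernels without full ConvDescription support wrap their raw
// instance string in InstanceStringDescription.
#include <iostream>
#include <memory>
#include "ck_tile/builder/reflect/description.hpp"

int main()
{
    std::unique_ptr<ck_tile::reflect::Description> desc =
        std::make_unique<ck_tile::reflect::InstanceStringDescription>(
            "DeviceFoo<fp16,256,...>"); // hypothetical instance string

    std::cout << desc->instance_string() << '\n'; // the string above
    std::cout << desc->brief() << '\n';           // "not supported"
    std::cout << desc->detailed() << '\n';        // "not supported"
}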

File: ck_tile/builder/types.hpp

@@ -11,15 +11,22 @@
namespace ck_tile::builder {
// TODO: Handle tuple types and FP8/BF8 properly
enum class DataType
{
UNDEFINED_DATA_TYPE = 0,
FP32,
FP32_FP32,
FP16,
FP16_FP16,
BF16,
BF16_BF16,
FP8,
BF8,
FP64,
INT32,
I8,
I8_I8,
U8
};
@@ -102,13 +109,44 @@ enum class ConvDirection
};
// Fused element-wise operations.
// TODO: Generalize design rather than enumerating all possible ops.
enum class ElementwiseOperation
{
ADD_CLAMP,
ADD_RELU_ADD,
ACTIVATION_MUL2_CLAMP,
ACTIVATION_MUL_CLAMP,
ADD_ACTIVATION_MUL_CLAMP,
ADD_ACTIVATION_MUL2_CLAMP,
ADD_MUL_ACTIVATION_MUL_CLAMP,
ADD_MUL2_ACTIVATION_MUL_CLAMP,
BIAS_BNORM_CLAMP,
BILINEAR,
SCALE,
SCALE_ADD,
CLAMP,
CONV_INVSCALE,
CONV_SCALE,
CONV_SCALE_ADD,
CONV_SCALE_RELU,
PASS_THROUGH,
SCALEADD_SCALEADD_RELU
SCALEADD_SCALEADD_RELU,
DYNAMIC_UNARY_OP,
UNARY_COMBINED_OP,
UNARY_CONVERT,
LOGISTIC,
CLIPPED_RELU,
SWISH,
ELU,
POWER,
LEAKY_RELU,
UNARY_ABS,
RELU,
SOFT_RELU,
SIGMOID,
TANH,
GELU,
SILU
};
// Enums for pipeline versions & schedulers
@@ -160,7 +198,8 @@ enum class ConvFwdSpecialization
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
FILTER_3x3,
ODD_C
};
// Enums for the backward data convolution specialization.
@@ -219,11 +258,17 @@ inline std::string_view toString(DataType dt)
switch(dt)
{
case FP16: return "FP16";
case FP16_FP16: return "FP16_FP16";
case FP32: return "FP32";
case FP32_FP32: return "FP32_FP32";
case BF16: return "BF16";
case BF16_BF16: return "BF16_BF16";
case FP8: return "FP8";
case BF8: return "BF8";
case FP64: return "FP64";
case INT32: return "INT32";
case I8: return "I8";
case I8_I8: return "I8_I8";
case U8: return "U8";
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
default: return "Unknown";
@@ -247,11 +292,41 @@ inline std::string_view toString(ElementwiseOperation op)
using enum ElementwiseOperation;
switch(op)
{
case ADD_CLAMP: return "ADD_CLAMP";
case ADD_RELU_ADD: return "ADD_RELU_ADD";
case ACTIVATION_MUL2_CLAMP: return "ACTIVATION_MUL2_CLAMP";
case ACTIVATION_MUL_CLAMP: return "ACTIVATION_MUL_CLAMP";
case ADD_ACTIVATION_MUL_CLAMP: return "ADD_ACTIVATION_MUL_CLAMP";
case ADD_ACTIVATION_MUL2_CLAMP: return "ADD_ACTIVATION_MUL2_CLAMP";
case ADD_MUL_ACTIVATION_MUL_CLAMP: return "ADD_MUL_ACTIVATION_MUL_CLAMP";
case ADD_MUL2_ACTIVATION_MUL_CLAMP: return "ADD_MUL2_ACTIVATION_MUL_CLAMP";
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
case BILINEAR: return "BILINEAR";
case CLAMP: return "CLAMP";
case SCALE: return "SCALE";
case SCALE_ADD: return "SCALE_ADD";
case CONV_INVSCALE: return "CONV_INVSCALE";
case CONV_SCALE: return "CONV_SCALE";
case CONV_SCALE_ADD: return "CONV_SCALE_ADD";
case CONV_SCALE_RELU: return "CONV_SCALE_RELU";
case PASS_THROUGH: return "PASS_THROUGH";
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
case DYNAMIC_UNARY_OP: return "DYNAMIC_UNARY_OP";
case UNARY_COMBINED_OP: return "UNARY_COMBINED_OP";
case UNARY_CONVERT: return "UNARY_CONVERT";
case LOGISTIC: return "LOGISTIC";
case CLIPPED_RELU: return "CLIPPED_RELU";
case SWISH: return "SWISH";
case ELU: return "ELU";
case POWER: return "POWER";
case LEAKY_RELU: return "LEAKY_RELU";
case UNARY_ABS: return "UNARY_ABS";
case RELU: return "RELU";
case SOFT_RELU: return "SOFT_RELU";
case SIGMOID: return "SIGMOID";
case TANH: return "TANH";
case GELU: return "GELU";
case SILU: return "SILU";
default: return "Unknown";
}
}
@@ -305,6 +380,7 @@ inline std::string_view toString(ConvFwdSpecialization spec)
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return "FILTER_3x3";
case ODD_C: return "ODD_C";
default: return "Unknown";
}
}
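
A quick sketch exercising the new enumerators against their toString() mappings:

// Sketch only: the enum additions above round-trip through toString().
#include <iostream>
#include <string_view>
#include "ck_tile/builder/types.hpp"

int main()
{
    using namespace ck_tile::builder;
    std::cout << toString(DataType::I8_I8) << '\n';              // "I8_I8"
    std::cout << toString(ElementwiseOperation::SILU) << '\n';   // "SILU"
    std::cout << toString(ConvFwdSpecialization::ODD_C) << '\n'; // "ODD_C"
}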

File: test (instance traits for fwd conv XDL variants)

@@ -5,6 +5,7 @@
#include <gmock/gmock.h>
#include <concepts>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>

File: test (ConvBuilder description)

@@ -4,8 +4,9 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"

File: test (instance string, DeviceGroupedConvBwdWeight_Xdl_CShuffle)

@@ -72,14 +72,16 @@ std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
",1" // MaxTransposeTransferSrcScalarPerVector
",1>"; // MaxTransposeTransferDstScalarPerVector
// Test GetInstanceString through base class pointer for backward weight XDL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForBwdWeightGrpConvXdl)
// Test describe() through base class pointer for backward weight XDL variant
TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvXdl)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForBwdWeightGrpConvXdl test once ckr::describe supports

File: test (instance string, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle)

@@ -2,10 +2,11 @@
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
namespace {
@@ -77,14 +78,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
",Default" // LoopScheduler
",1>"; // NumGroupsToMerge
// Test GetInstanceString through base class pointer for standard XDL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConv)
// Test describe() through base class pointer for standard XDL variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConv)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConv)

File: test (instance string, DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK)

@@ -71,14 +71,16 @@ std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
",5" // CThreadTransferSrcDstVectorDim
",1>"; // CThreadTransferDstScalarPerVector
// Test GetInstanceString through base class pointer for DL variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvDl)
// Test describe() through base class pointer for DL variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvDl)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvDl test once ckr::describe supports DL

File: test (instance string, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor)

@@ -2,10 +2,11 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
namespace {
@@ -76,14 +77,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten
",fp16" // BComputeDataType
",Default>"; // LoopScheduler
// Test GetInstanceString through base class pointer for large tensor variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvLargeTensor)
// Test describe() through base class pointer for large tensor variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvLargeTensor)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvLargeTensor)

File: test (instance string, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3)

@@ -3,6 +3,7 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
@@ -78,14 +79,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
",fp16" // BComputeDataType
",false>"; // DirectLoad
// Test GetInstanceString through base class pointer for V3 variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvV3)
// Test describe() through base class pointer for V3 variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvV3)

File: test (instance string, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle)

@@ -76,14 +76,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
",Default" // LoopSched
",v1>"; // PipelineVer
// Test GetInstanceString through base class pointer for WMMA variant
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvWmma)
// Test describe() through base class pointer for WMMA variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmma)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvWmma test once ckr::describe supports WMMA

File: ck/tensor_operation/gpu/device/device_base.hpp

@@ -8,8 +8,13 @@
#include <sstream>
#include <regex>
#include <optional>
#include <memory>
#include "ck/stream_config.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#endif
#endif
#include "ck/utility/get_id.hpp"
@@ -227,6 +232,12 @@ struct BaseOperator
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
#ifdef CK_EXPERIMENTAL_BUILDER
// Return a description object for this operator, or nullptr if not supported.
virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
#endif
virtual std::string GetInstanceString() const { return ""; }
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

File: device op DeviceGroupedConvBwdWeight_Xdl_CShuffle

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#endif
@@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

File: device op DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#endif
@@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

File: device op DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

@@ -29,6 +29,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
@@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

File: device op DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3

@@ -29,6 +29,7 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#endif
@@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

File: device op DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#endif
@@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

File: device op DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#endif
@@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(
ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
};