mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add describe() method to device ops for runtime introspection (#3375)
Introduces a polymorphic describe() method to BaseOperator that enables runtime introspection of kernel configurations through a unified interface.

Key changes:
* Add virtual describe() method to BaseOperator returning Description objects
* Implement describe() in 6 device operation classes (conv fwd/bwd variants)
* Create conv_describe.hpp with factory function for ConvDescription
* Extract type definitions to conv_types.hpp to resolve circular dependencies
* Add InstanceStringDescription for kernels without full ConvDescription support

Other improvements:
* Update tests to use describe() instead of GetInstanceString()
* Remove circular-dependency include from conv_traits.hpp
* Add ODD_C to ConvFwdSpecialization enum and fix the OddC mapping
* Replace the silent fallback in conv_layout() with a compile-time error

This provides a foundation for runtime kernel introspection and better tooling support for analyzing and debugging kernel configurations.
This commit is contained in:
@@ -153,6 +153,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/// @file
|
||||
/// @brief Implementation of the describe() function template for convolution kernels
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_description.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <conv::HasConvTraits Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
|
||||
return conv::ConvDescription(
|
||||
conv::ConvSignatureInfo{
|
||||
.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.input_layout = Traits::layout[0],
|
||||
.weight_layout = Traits::layout[1],
|
||||
.output_layout = Traits::layout[2],
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op,
|
||||
},
|
||||
conv::GemmAlgorithmInfo{
|
||||
.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding,
|
||||
},
|
||||
[]() { return reflect::instance_string<Instance>(); });
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -25,7 +25,7 @@
|
||||
#include <functional>
|
||||
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_types.hpp>
|
||||
#include <ck_tile/builder/reflect/description.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/tree_formatter.hpp>
|
||||
@@ -249,41 +249,7 @@ class ConvDescription : public Description
|
||||
GemmAlgorithmInfo algorithm_;
|
||||
std::function<std::string()> instance_string_getter_;
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <conv::HasConvTraits Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
|
||||
return conv::ConvDescription(
|
||||
conv::ConvSignatureInfo{
|
||||
.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.input_layout = Traits::layout[0],
|
||||
.weight_layout = Traits::layout[1],
|
||||
.output_layout = Traits::layout[2],
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op,
|
||||
},
|
||||
conv::GemmAlgorithmInfo{
|
||||
.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding,
|
||||
},
|
||||
[]() { return reflect::instance_string<Instance>(); });
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_types.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
@@ -161,103 +161,19 @@ constexpr auto convert_pipeline_scheduler()
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
|
||||
/// @brief Per-workgroup GEMM data tile dimensions.
/// @details Holds the M, N, and K extents of the data tile that a single
/// workgroup (thread block) processes in the underlying GEMM computation.
struct DataTileInfo
{
    int m; ///< Tile extent along M for the workgroup (MPerBlock).
    int n; ///< Tile extent along N for the workgroup (NPerBlock).
    int k; ///< Tile extent along K for the workgroup (KPerBlock).
};
|
||||
|
||||
/// @brief Shape of an input (A or B matrix) tile staged from global memory to LDS.
/// @details The tile's K extent is factored as K = k0 * k1, where k1 is
/// typically the vectorized load width from global memory.
struct InputTileTransferDimensions
{
    int k0;     ///< Outer K factor (K = k0 * k1).
    int m_or_n; ///< M extent for an A-matrix transfer, or N extent for a B-matrix transfer.
    int k1;     ///< Inner K factor, often the global-memory vector load size.
};
|
||||
|
||||
/// @brief Configuration of a global-memory-to-LDS input tile transfer.
/// @details Captures thread clustering, memory access ordering, and
/// vectorization settings used when loading an input tile into LDS.
struct InputTileTransferParams
{
    int k1; ///< Inner K extent, often equal to the vectorization width.
    std::array<int, 3> thread_cluster_dims;  ///< Threads arranged per axis over the input tile.
    std::array<int, 3> thread_cluster_order; ///< Order of thread distribution over the input
                                             ///< tensor dimensions.
    std::array<int, 3> src_access_order;     ///< Order in which input tensor axes are read.
    int src_vector_dim;           ///< Axis index of vectorized (contiguous) memory access.
    int src_scalar_per_vector;    ///< Elements accessed per thread per vector instruction.
    int dst_scalar_per_vector_k1; ///< Vectorized LDS store width along the K1 dimension.
    bool lds_padding;             ///< Whether the LDS tensor is padded to avoid bank conflicts.
};
|
||||
|
||||
/// @brief Complete information for an input tile transfer.
|
||||
/// @details Combines the dimensional information and transfer parameters for
|
||||
/// a full description of an input tile's journey from global memory to LDS.
|
||||
struct InputTileTransferInfo
|
||||
{
|
||||
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
|
||||
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
|
||||
};
|
||||
|
||||
/// @brief Warp-level GEMM (MFMA) configuration.
/// @details Describes the per-instruction MFMA shape and how many MFMA
/// iterations each wavefront performs over its output tile.
struct WarpGemmParams
{
    int gemm_m; ///< M extent of one MFMA instruction (MPerXdl).
    int gemm_n; ///< N extent of one MFMA instruction (NPerXdl).
    int m_iter; ///< MFMA iterations per wavefront along M (MXdlPerWave).
    int n_iter; ///< MFMA iterations per wavefront along N (NXdlPerWave).
};
|
||||
|
||||
/// @brief Cross-warp data movement (CShuffle) configuration.
/// @details Number of MFMA instruction results each wave handles per
/// CShuffle iteration, along the M and N dimensions.
struct WarpShuffleParams
{
    int m_gemms_per_shuffle; ///< MFMA results along M processed per wave per shuffle iteration.
    int n_gemms_per_shuffle; ///< MFMA results along N processed per wave per shuffle iteration.
};
|
||||
|
||||
/// @brief Information for the output tile transfer (CShuffle).
|
||||
/// @details Describes how the final computed tile (C matrix) is written out from
|
||||
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
|
||||
struct OutputTileTransferInfo
|
||||
{
|
||||
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
|
||||
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
|
||||
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
|
||||
///< data into the output tensor.
|
||||
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
|
||||
///< output tensor.
|
||||
};
|
||||
|
||||
// Helper metafunctions to derive signature information from Instance types
|
||||
|
||||
/// @brief Helper function to report unsupported convolution direction with a clear error message.
|
||||
template <typename Instance>
|
||||
consteval void report_unsupported_conv_direction_error()
|
||||
{
|
||||
throw "Unsupported convolution direction detected!\n"
|
||||
"The kernel instance does not have a recognized convolution specialization.\n"
|
||||
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
|
||||
"kConvBwdWeightSpecialization.\n"
|
||||
"Please verify that your kernel instance is properly configured.";
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
@@ -273,7 +189,10 @@ constexpr builder::ConvDirection conv_direction()
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
else
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
{
|
||||
report_unsupported_conv_direction_error<Instance>();
|
||||
return builder::ConvDirection::FORWARD; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
@@ -296,6 +215,7 @@ constexpr auto conv_spec()
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
@@ -334,6 +254,20 @@ template <typename A,
|
||||
inline constexpr bool layouts_are =
|
||||
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Helper function to report unsupported layout combinations with a clear error message.
|
||||
/// @details This consteval function is designed to fail at compile time with a descriptive
|
||||
/// error message when an unsupported layout combination is encountered.
|
||||
template <typename A, typename B, typename E, int SpatialDim>
|
||||
consteval void report_unsupported_layout_error()
|
||||
{
|
||||
// This will produce a compile-time error with the exception message
|
||||
throw "Unsupported convolution layout combination detected!\n"
|
||||
"The combination of ALayout, BLayout, and ELayout template parameters\n"
|
||||
"is not recognized for the given spatial dimension.\n"
|
||||
"Please verify that your convolution instance uses a supported layout configuration.\n"
|
||||
"Check the conv_layout() function for the list of supported layout combinations.";
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
@@ -358,6 +292,8 @@ constexpr auto conv_layout()
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
@@ -368,8 +304,12 @@ constexpr auto conv_layout()
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
@@ -378,6 +318,8 @@ constexpr auto conv_layout()
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
@@ -386,11 +328,31 @@ constexpr auto conv_layout()
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
|
||||
// If we reach here, the layout combination is not supported
|
||||
// Call consteval function to trigger a compile-time error with a clear message
|
||||
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
|
||||
// This return is unreachable but needed to satisfy the compiler
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported data type with a clear error message.
|
||||
template <typename ADataType>
|
||||
consteval void report_unsupported_data_type_error()
|
||||
{
|
||||
throw "Unsupported data type detected!\n"
|
||||
"The ADataType is not recognized.\n"
|
||||
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
|
||||
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
|
||||
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
|
||||
"(BF8), "
|
||||
"int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
|
||||
"Please verify that your kernel instance uses a supported data type.";
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
@@ -401,18 +363,50 @@ constexpr builder::DataType conv_data_type()
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
return FP32; // Default fallback
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
|
||||
template <typename ElementwiseOp>
|
||||
consteval void report_unsupported_elementwise_op_error()
|
||||
{
|
||||
throw "Unsupported elementwise operation detected!\n"
|
||||
"The elementwise operation type is not recognized.\n"
|
||||
"Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
|
||||
"BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
|
||||
"ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
|
||||
"UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
|
||||
"Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
|
||||
"UnaryConvert.\n"
|
||||
"Please verify that your kernel instance uses a supported elementwise operation.";
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
@@ -424,16 +418,83 @@ constexpr builder::ElementwiseOperation elementwise_op()
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
|
||||
return ADD_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
|
||||
return ADD_RELU_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
|
||||
return BILINEAR;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
|
||||
return CONV_INVSCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
|
||||
return CONV_SCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
|
||||
return CONV_SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
|
||||
return CONV_SCALE_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
|
||||
return SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
|
||||
return DYNAMIC_UNARY_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
|
||||
return UNARY_COMBINED_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
|
||||
return ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
|
||||
return ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
|
||||
return ADD_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
|
||||
return ADD_ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
|
||||
return ADD_MUL_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
|
||||
return ADD_MUL2_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
|
||||
return UNARY_CONVERT;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
|
||||
return LOGISTIC;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
|
||||
return CLIPPED_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Swish"))
|
||||
return SWISH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Elu"))
|
||||
return ELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Power"))
|
||||
return POWER;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
|
||||
return LEAKY_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
|
||||
return UNARY_ABS;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Relu"))
|
||||
return RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
|
||||
return SOFT_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
|
||||
return SIGMOID;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "TanH"))
|
||||
return TANH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
|
||||
return GELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Silu"))
|
||||
return SILU;
|
||||
else
|
||||
{
|
||||
report_unsupported_elementwise_op_error<ElementwiseOp>();
|
||||
return PASS_THROUGH; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
@@ -606,45 +667,4 @@ struct ConvTraits<Instance>
|
||||
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
|
||||
};
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
/// Every member here is a constexpr forward of the corresponding member of
/// `ConvTraits<Instance>`; no trait values are computed in this specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
          builder::ConvAlgorithmDescriptor auto ALGORITHM,
          builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
    // The concrete device kernel type the builder materializes.
    using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;

    // Delegate to Instance-based ConvTraits
    using InstanceConvTraits = ConvTraits<Instance>;

    // Forward all members from Instance-based traits

    // Problem signature: dimensionality, direction, layouts, data type.
    static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
    static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
    static constexpr auto layout = InstanceConvTraits::layout;
    static constexpr builder::DataType data_type = InstanceConvTraits::data_type;

    // Elementwise operations applied to input, weight, and output tensors.
    static constexpr builder::ElementwiseOperation input_element_op =
        InstanceConvTraits::input_element_op;
    static constexpr builder::ElementwiseOperation weight_element_op =
        InstanceConvTraits::weight_element_op;
    static constexpr builder::ElementwiseOperation output_element_op =
        InstanceConvTraits::output_element_op;

    // GEMM padding and convolution specialization flags.
    static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
    static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;

    // Algorithm configuration: block size, tiling, transfers, and pipeline.
    static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
    static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
    static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
    static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
    static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
    static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
    static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
    static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/// @file
|
||||
/// @brief Type definitions for convolution reflection
|
||||
///
|
||||
/// This file contains the type definitions used by both conv_traits.hpp and conv_description.hpp
|
||||
/// to avoid circular dependencies.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief GEMM tile extents handled by one workgroup.
/// @details M, N, and K sizes of the data tile a single thread block is
/// responsible for in the underlying GEMM computation.
struct DataTileInfo
{
    int m; ///< Workgroup tile size along M (MPerBlock).
    int n; ///< Workgroup tile size along N (NPerBlock).
    int k; ///< Workgroup tile size along K (KPerBlock).
};
|
||||
|
||||
/// @brief Shape of an input tile (A or B matrix) moved from global memory to LDS.
/// @details The K extent is split into two factors, K = k0 * k1, where k1
/// usually matches the global-memory vector load size.
struct InputTileTransferDimensions
{
    int k0;     ///< Outer K factor (K = k0 * k1).
    int m_or_n; ///< M extent for the A-matrix transfer, N extent for the B-matrix.
    int k1;     ///< Inner K factor, often the vector load size from global memory.
};
|
||||
|
||||
/// @brief Settings for loading an input tile from global memory into LDS.
/// @details Covers thread clustering, source access ordering, and the
/// vectorization of both the global-memory read and the LDS write.
struct InputTileTransferParams
{
    int k1; ///< Inner K extent; typically equals the vectorization width.
    std::array<int, 3> thread_cluster_dims;  ///< Threads per axis over the input data tile.
    std::array<int, 3> thread_cluster_order; ///< Axis order of thread distribution over the
                                             ///< input tensor dimensions.
    std::array<int, 3> src_access_order;     ///< Axis order for reading the input tensor.
    int src_vector_dim;           ///< Axis index used for vectorized (contiguous) access.
    int src_scalar_per_vector;    ///< Elements per thread per vector access instruction.
    int dst_scalar_per_vector_k1; ///< Vectorized LDS store size along the K1 dimension.
    bool lds_padding;             ///< True if LDS is padded to prevent bank conflicts.
};
|
||||
|
||||
/// @brief Complete information for an input tile transfer.
|
||||
/// @details Combines the dimensional information and transfer parameters for
|
||||
/// a full description of an input tile's journey from global memory to LDS.
|
||||
struct InputTileTransferInfo
|
||||
{
|
||||
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
|
||||
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
|
||||
};
|
||||
|
||||
/// @brief Warp-level GEMM configuration.
/// @details Describes the per-warp MFMA (Matrix Fused Multiply-Add) tiling:
/// the shape of a single instruction and how many instruction iterations each
/// wavefront performs over its portion of the output tile.
struct WarpGemmParams
{
    int gemm_m; ///< M extent of one MFMA instruction (MPerXdl).
    int gemm_n; ///< N extent of one MFMA instruction (NPerXdl).
    int m_iter; ///< MFMA iterations per wavefront along M (MXdlPerWave).
    int n_iter; ///< MFMA iterations per wavefront along N (NXdlPerWave).
};
|
||||
|
||||
/// @brief Cross-warp shuffle configuration (CShuffle optimization).
/// @details States how many MFMA instruction results each wave processes per
/// iteration of the CShuffle routine, along each output dimension.
struct WarpShuffleParams
{
    int m_gemms_per_shuffle; ///< MFMA results along M handled per wave per shuffle iteration.
    int n_gemms_per_shuffle; ///< MFMA results along N handled per wave per shuffle iteration.
};
|
||||
|
||||
/// @brief Information for the output tile transfer (CShuffle).
|
||||
/// @details Describes how the final computed tile (C matrix) is written out from
|
||||
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
|
||||
struct OutputTileTransferInfo
|
||||
{
|
||||
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
|
||||
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
|
||||
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
|
||||
///< data into the output tensor.
|
||||
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
|
||||
///< output tensor.
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -20,6 +20,11 @@ namespace ck_tile::reflect {
|
||||
class Description
|
||||
{
|
||||
public:
|
||||
Description() = default;
|
||||
Description(const Description&) = default;
|
||||
Description(Description&&) = default;
|
||||
Description& operator=(const Description&) = default;
|
||||
Description& operator=(Description&&) = default;
|
||||
/// @brief Virtual destructor for proper cleanup of derived classes
|
||||
virtual ~Description() = default;
|
||||
|
||||
@@ -36,4 +41,30 @@ class Description
|
||||
virtual std::string instance_string() const = 0;
|
||||
};
|
||||
|
||||
/// @brief A specialized Description that only supports instance_string()
|
||||
/// This is a helper class for kernels that don't yet have full ConvDescription support.
|
||||
/// The brief() and detailed() methods return "not supported" placeholders.
|
||||
class InstanceStringDescription : public Description
|
||||
{
|
||||
public:
|
||||
/// @brief Construct with an instance string
|
||||
/// @param instance The instance string to store
|
||||
explicit InstanceStringDescription(std::string instance) : instance_(std::move(instance)) {}
|
||||
|
||||
/// @brief Returns "not supported" as brief descriptions are not implemented
|
||||
/// @return A placeholder string indicating the feature is not supported
|
||||
std::string brief() const override { return "not supported"; }
|
||||
|
||||
/// @brief Returns "not supported" as detailed descriptions are not implemented
|
||||
/// @return A placeholder string indicating the feature is not supported
|
||||
std::string detailed() const override { return "not supported"; }
|
||||
|
||||
/// @brief Returns the stored instance string
|
||||
/// @return The instance string provided during construction
|
||||
std::string instance_string() const override { return instance_; }
|
||||
|
||||
private:
|
||||
std::string instance_; ///< The stored instance string
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
|
||||
@@ -11,15 +11,22 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// TODO: Handle tuple types and FP8/BF8 properly
|
||||
enum class DataType
|
||||
{
|
||||
UNDEFINED_DATA_TYPE = 0,
|
||||
FP32,
|
||||
FP32_FP32,
|
||||
FP16,
|
||||
FP16_FP16,
|
||||
BF16,
|
||||
BF16_BF16,
|
||||
FP8,
|
||||
BF8,
|
||||
FP64,
|
||||
INT32,
|
||||
I8,
|
||||
I8_I8,
|
||||
U8
|
||||
};
|
||||
|
||||
@@ -102,13 +109,44 @@ enum class ConvDirection
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
// TODO: Generalize design rather than enumerating all possible ops.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
ADD_CLAMP,
|
||||
ADD_RELU_ADD,
|
||||
ACTIVATION_MUL2_CLAMP,
|
||||
ACTIVATION_MUL_CLAMP,
|
||||
ADD_ACTIVATION_MUL_CLAMP,
|
||||
ADD_ACTIVATION_MUL2_CLAMP,
|
||||
ADD_MUL_ACTIVATION_MUL_CLAMP,
|
||||
ADD_MUL2_ACTIVATION_MUL_CLAMP,
|
||||
BIAS_BNORM_CLAMP,
|
||||
BILINEAR,
|
||||
SCALE,
|
||||
SCALE_ADD,
|
||||
CLAMP,
|
||||
CONV_INVSCALE,
|
||||
CONV_SCALE,
|
||||
CONV_SCALE_ADD,
|
||||
CONV_SCALE_RELU,
|
||||
PASS_THROUGH,
|
||||
SCALEADD_SCALEADD_RELU
|
||||
SCALEADD_SCALEADD_RELU,
|
||||
DYNAMIC_UNARY_OP,
|
||||
UNARY_COMBINED_OP,
|
||||
UNARY_CONVERT,
|
||||
LOGISTIC,
|
||||
CLIPPED_RELU,
|
||||
SWISH,
|
||||
ELU,
|
||||
POWER,
|
||||
LEAKY_RELU,
|
||||
UNARY_ABS,
|
||||
RELU,
|
||||
SOFT_RELU,
|
||||
SIGMOID,
|
||||
TANH,
|
||||
GELU,
|
||||
SILU
|
||||
};
|
||||
|
||||
// Enums for pipeline versions & schedulers
|
||||
@@ -160,7 +198,8 @@ enum class ConvFwdSpecialization
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_3x3
|
||||
FILTER_3x3,
|
||||
ODD_C
|
||||
};
|
||||
|
||||
// Enums for the backward data convolution specialization.
|
||||
@@ -219,11 +258,17 @@ inline std::string_view toString(DataType dt)
|
||||
switch(dt)
|
||||
{
|
||||
case FP16: return "FP16";
|
||||
case FP16_FP16: return "FP16_FP16";
|
||||
case FP32: return "FP32";
|
||||
case FP32_FP32: return "FP32_FP32";
|
||||
case BF16: return "BF16";
|
||||
case BF16_BF16: return "BF16_BF16";
|
||||
case FP8: return "FP8";
|
||||
case BF8: return "BF8";
|
||||
case FP64: return "FP64";
|
||||
case INT32: return "INT32";
|
||||
case I8: return "I8";
|
||||
case I8_I8: return "I8_I8";
|
||||
case U8: return "U8";
|
||||
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
|
||||
default: return "Unknown";
|
||||
@@ -247,11 +292,41 @@ inline std::string_view toString(ElementwiseOperation op)
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
{
|
||||
case ADD_CLAMP: return "ADD_CLAMP";
|
||||
case ADD_RELU_ADD: return "ADD_RELU_ADD";
|
||||
case ACTIVATION_MUL2_CLAMP: return "ACTIVATION_MUL2_CLAMP";
|
||||
case ACTIVATION_MUL_CLAMP: return "ACTIVATION_MUL_CLAMP";
|
||||
case ADD_ACTIVATION_MUL_CLAMP: return "ADD_ACTIVATION_MUL_CLAMP";
|
||||
case ADD_ACTIVATION_MUL2_CLAMP: return "ADD_ACTIVATION_MUL2_CLAMP";
|
||||
case ADD_MUL_ACTIVATION_MUL_CLAMP: return "ADD_MUL_ACTIVATION_MUL_CLAMP";
|
||||
case ADD_MUL2_ACTIVATION_MUL_CLAMP: return "ADD_MUL2_ACTIVATION_MUL_CLAMP";
|
||||
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
|
||||
case BILINEAR: return "BILINEAR";
|
||||
case CLAMP: return "CLAMP";
|
||||
case SCALE: return "SCALE";
|
||||
case SCALE_ADD: return "SCALE_ADD";
|
||||
case CONV_INVSCALE: return "CONV_INVSCALE";
|
||||
case CONV_SCALE: return "CONV_SCALE";
|
||||
case CONV_SCALE_ADD: return "CONV_SCALE_ADD";
|
||||
case CONV_SCALE_RELU: return "CONV_SCALE_RELU";
|
||||
case PASS_THROUGH: return "PASS_THROUGH";
|
||||
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
|
||||
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
|
||||
case DYNAMIC_UNARY_OP: return "DYNAMIC_UNARY_OP";
|
||||
case UNARY_COMBINED_OP: return "UNARY_COMBINED_OP";
|
||||
case UNARY_CONVERT: return "UNARY_CONVERT";
|
||||
case LOGISTIC: return "LOGISTIC";
|
||||
case CLIPPED_RELU: return "CLIPPED_RELU";
|
||||
case SWISH: return "SWISH";
|
||||
case ELU: return "ELU";
|
||||
case POWER: return "POWER";
|
||||
case LEAKY_RELU: return "LEAKY_RELU";
|
||||
case UNARY_ABS: return "UNARY_ABS";
|
||||
case RELU: return "RELU";
|
||||
case SOFT_RELU: return "SOFT_RELU";
|
||||
case SIGMOID: return "SIGMOID";
|
||||
case TANH: return "TANH";
|
||||
case GELU: return "GELU";
|
||||
case SILU: return "SILU";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
@@ -305,6 +380,7 @@ inline std::string_view toString(ConvFwdSpecialization spec)
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return "FILTER_3x3";
|
||||
case ODD_C: return "ODD_C";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <gmock/gmock.h>
|
||||
#include <concepts>
|
||||
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_description.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
|
||||
@@ -72,14 +72,16 @@ std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
|
||||
",1" // MaxTransposeTransferSrcScalarPerVector
|
||||
",1>"; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
// Test GetInstanceString through base class pointer for backward weight XDL variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForBwdWeightGrpConvXdl)
|
||||
// Test describe() through base class pointer for backward weight XDL variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForBwdWeightGrpConvXdl)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
// TODO: Add DescriptionReturnsCorrectValueForBwdWeightGrpConvXdl test once ckr::describe supports
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_description.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include <ck_tile/builder/reflect/conv_describe.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -77,14 +78,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
",Default" // LoopScheduler
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
// Test GetInstanceString through base class pointer for standard XDL variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConv)
|
||||
// Test describe() through base class pointer for standard XDL variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConv)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConv)
|
||||
|
||||
@@ -71,14 +71,16 @@ std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
|
||||
",5" // CThreadTransferSrcDstVectorDim
|
||||
",1>"; // CThreadTransferDstScalarPerVector
|
||||
|
||||
// Test GetInstanceString through base class pointer for DL variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvDl)
|
||||
// Test describe() through base class pointer for DL variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvDl)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvDl test once ckr::describe supports DL
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_describe.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -76,14 +77,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten
|
||||
",fp16" // BComputeDataType
|
||||
",Default>"; // LoopScheduler
|
||||
|
||||
// Test GetInstanceString through base class pointer for large tensor variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvLargeTensor)
|
||||
// Test describe() through base class pointer for large tensor variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvLargeTensor)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvLargeTensor)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_describe.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
|
||||
@@ -78,14 +79,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
",fp16" // BComputeDataType
|
||||
",false>"; // DirectLoad
|
||||
|
||||
// Test GetInstanceString through base class pointer for V3 variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvV3)
|
||||
// Test describe() through base class pointer for V3 variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvV3)
|
||||
|
||||
@@ -76,14 +76,16 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
|
||||
",Default" // LoopSched
|
||||
",v1>"; // PipelineVer
|
||||
|
||||
// Test GetInstanceString through base class pointer for WMMA variant
|
||||
TEST(InstanceString, GetInstanceStringReturnsCorrectValueForFwdGrpConvWmma)
|
||||
// Test describe() through base class pointer for WMMA variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmma)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
EXPECT_EQ(base_ptr->GetInstanceString(), expected_str);
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
// TODO: Add DescriptionReturnsCorrectValueForFwdGrpConvWmma test once ckr::describe supports WMMA
|
||||
|
||||
@@ -8,8 +8,13 @@
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <optional>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
|
||||
@@ -227,6 +232,12 @@ struct BaseOperator
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
// Return a description object for this operator, or nullptr if not supported.
|
||||
virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
|
||||
#endif
|
||||
|
||||
virtual std::string GetInstanceString() const { return ""; }
|
||||
|
||||
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
static_assert(ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
|
||||
"ConvTraits specialization not found for this device operation. "
|
||||
"If you modified the template parameters of this class, ensure that "
|
||||
"the corresponding ConvTraits specialization in "
|
||||
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
|
||||
"InstanceTraits in "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
|
||||
"provides all required members for ConvTraits to work.");
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#endif
|
||||
|
||||
@@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
static_assert(
|
||||
ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
|
||||
"ConvTraits specialization not found for this device operation. "
|
||||
"If you modified the template parameters of this class, ensure that "
|
||||
"the corresponding ConvTraits specialization in "
|
||||
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
|
||||
"InstanceTraits in "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
|
||||
"provides all required members for ConvTraits to work.");
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user