[CK_BUILDER] Convert convolution traits to a struct with factory functions (#3547)

* Factor helpers out of conv_traits.hpp

* Create a non-templated conv_traits struct

* Migrate to new instance-specific instance_to_conv_traits functions

* Clean up reflection concepts

* Clean up ConvTraits helpers

* Update testing for convolution traits

This is a lot of cleanup on tests to have verbose coverage of feature
extraction, explicit tests for each supported device kernel, and
simple, readable test code.

* Address reviewer comments and resolve merge conflict
This commit is contained in:
John Shumway
2026-01-15 01:03:21 -08:00
committed by GitHub
parent 8705fdcb0c
commit 5122637215
17 changed files with 2288 additions and 1875 deletions

View File

@@ -7,43 +7,52 @@
#pragma once
#include "ck_tile/builder/reflect/conv_description.hpp"
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp"
namespace ck_tile::reflect {
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have ConvTraits)
/// @return A ConvDescription object populated with the instance's configuration details
template <conv::HasConvTraits Instance>
/// @brief Concept to check if an Instance type has conv traits
template <typename Instance>
concept HasConvTraits = requires {
{ conv::instance_to_conv_traits<Instance>() };
};
/// Factory function to create ConvDescription from a convolution instance type
/// Instance The convolution instance type
/// A ConvDescription object populated with the instance's configuration details
///
/// TODO: Fix ConvDescription to just use the ConvTraits directly.
template <typename Instance>
requires HasConvTraits<Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;
const auto traits = conv::instance_to_conv_traits<Instance>();
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.input_layout = Traits::layout[0],
.weight_layout = Traits::layout[1],
.output_layout = Traits::layout[2],
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op,
.spatial_dim = traits.spatial_dim,
.direction = traits.direction,
.input_layout = traits.layout[0],
.weight_layout = traits.layout[1],
.output_layout = traits.layout[2],
.data_type = traits.data_type,
.input_element_op = traits.input_element_op,
.weight_element_op = traits.weight_element_op,
.output_element_op = traits.output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding,
.thread_block_size = traits.thread_block_size,
.tile_dims = traits.tile_dims,
.warp_gemm = traits.warp_gemm,
.a_tile_transfer = traits.a_tile_transfer,
.b_tile_transfer = traits.b_tile_transfer,
.c_tile_transfer = traits.c_tile_transfer,
.pipeline_version = traits.pipeline_version,
.pipeline_scheduler = traits.pipeline_scheduler,
.conv_specialization = traits.conv_specialization,
.padding = traits.gemm_padding,
},
[]() { return reflect::instance_string<Instance>(); });
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
}
} // namespace ck_tile::reflect

View File

@@ -1,664 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Runtime-accessible convolution kernel configuration data structure
//
// This file defines ConvTraits, a pure data structure that captures the complete
// configuration of a convolution kernel in a domain-specific abstraction, without
// requiring knowledge of the underlying kernel instance implementation details.
//
// ## Purpose and Design
//
// ConvTraits provides type erasure for convolution kernel configurations, allowing
// for reflection of convolution kernel objects. The struct represents kernel
// traits in terms of convolution-specific concepts for AMD GPUs rather than raw
// template parameters.
//
// ## Architecture and Usage
//
// ConvTraits sits at the center of the reflection system:
//
// 1. **Population**: Values are created by `instance_to_conv_traits()` template
// specializations that extract configuration from compile-time InstanceTraits
//
// 2. **Consumption**: Used by ConvDescription to provide human-readable descriptions
// of kernel configurations for debugging, logging, and documentation
//
// ## Structure Organization
//
// The struct separates kernel configuration into two logical categories:
//
// - **Signature Information**: Defines what the kernel computes (direction, layouts,
// data types, elementwise operations, specializations)
//
// - **Algorithm Information**: Defines how the kernel computes (thread block size,
// tile dimensions, memory access patterns, pipeline configuration)
//
// ## Evolution and Extensibility
//
// ConvTraits is designed to evolve through composition (not inheritance):
//
// - Currently supports XDL forward convolution kernels
// - Will extend to the other forward convolutions
// - Will be extended to cover backward data and backward weight convolutions
// - Will incorporate fusion operations and additional specializations
// - Uses std::optional and std::variant for optional/variant fields
// - Eventually will generalize to KernelTraits for GEMM, flash attention, etc.
#pragma once
#include <concepts>
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/reflect/conv_types.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
namespace ck_tile::reflect::conv {
// Forward convolution layout concept - checks for A/B/E layout types
template <typename T>
concept HasFwdConvLayouts = requires {
typename T::ALayout;
typename T::BLayout;
typename T::ELayout;
};
// GEMM specialization concept - checks for a static kGemmSpecialization member that
// is convertible to CK's GemmSpecialization enum. Instances satisfying this concept
// can have their GEMM padding derived via gemm_spec().
template <typename T>
concept HasGemmSpec = requires {
    {
        T::kGemmSpecialization
    } -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
};
// Data types concept - checks for an ADataType member alias. The A-tensor data type
// is used as the representative data type of the whole instance (see conv_data_type()).
template <typename T>
concept HasDataTypes = requires { typename T::ADataType; };
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
template <typename T>
concept HasElementwiseOps = requires {
typename T::AElementwiseOperation;
typename T::BElementwiseOperation;
typename T::CDEElementwiseOperation;
};
// Tile parameters concept - checks for the block-tile dimensions (K/M/N per block),
// the A/B K1 vector widths, and the C thread cluster lengths that the algorithm
// information extraction relies on.
template <typename T>
concept HasTileParams = requires {
    { T::kKPerBlock } -> std::convertible_to<int>;
    { T::kMPerBlock } -> std::convertible_to<int>;
    { T::kNPerBlock } -> std::convertible_to<int>;
    { T::kAK1 } -> std::convertible_to<int>;
    { T::kBK1 } -> std::convertible_to<int>;
    T::kCThreadClusterLengths; // only existence is required, not a particular type
};
// Comprehensive concept that checks if an instance has all XDL forward convolution traits.
// This concept is used to constrain ConvTraits specializations that expect XDL forward
// convolutions.
template <typename T>
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
                       HasElementwiseOps<T> && HasTileParams<T>;
// Primary concept for checking if a type can be described via ConvTraits.
// Note: the check is applied to InstanceTraits<T>, not to T itself.
// Currently only XDL forward convolutions are supported, but this can be extended
// in the future to include backward data and backward weight convolutions.
template <typename T>
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
// Runtime data structure representing a convolution kernel's complete configuration
//
// This pure data struct (no template parameters, no static members) provides
// type erasure for convolution kernel configurations. It can hold the configuration
// from any convolution kernel instance, enabling runtime storage, comparison, and
// manipulation of kernel properties.
//
// The struct is populated by `instance_to_conv_traits()` template specializations
// that extract compile-time configuration from InstanceTraits and convert it to
// this standardized runtime representation.
//
// Members are organized into two categories:
// - **Signature Information**: Defines the computational interface (what to compute)
// - **Algorithm Information**: Defines the implementation strategy (how to compute)
//
// Note: This struct will evolve to support additional convolution variants and
// eventually generalize to other kernel types through composition.
//
// There is a lot we still need to do:
//
// TODO: Generalize type support for all tensors and accumulator.
// TODO: Describe all tensors.
// TODO: Include the full generalization of the signature from the input schema.
// TODO: Include the full generalization of the algorithm from the input schema.
struct ConvTraits
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
switch(ck_ver)
{
case v1: return V1;
case v2: return V2;
case v3: return V3;
case v4: return V4;
case v5: return V5;
}
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
/// @details Maps CK's general pipeline version identifiers to the builder framework's
/// standardized pipeline version enum. Unlike the BlockGemmPipelineVersion overload,
/// this variant handles a different set of versions, including the specialized
/// weight-only pipeline.
template <ck::PipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
    using enum ck::PipelineVersion;
    using enum builder::PipelineVersion;
    if constexpr(ck_ver == v1)
        return V1;
    else if constexpr(ck_ver == v2)
        return V2;
    else if constexpr(ck_ver == v4)
        return V4;
    else if constexpr(ck_ver == weight_only)
        return WEIGHT_ONLY;
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
/// @details Maps CK's block GEMM pipeline scheduler identifiers onto the builder
/// framework's standardized scheduler enum. INTRAWAVE scheduling operates within a
/// single wavefront, while INTERWAVE coordinates across multiple wavefronts during
/// pipeline execution.
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
    using enum ck::BlockGemmPipelineScheduler;
    using enum builder::PipelineScheduler;
    if constexpr(ck_sched == Intrawave)
        return INTRAWAVE;
    else if constexpr(ck_sched == Interwave)
        return INTERWAVE;
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
/// @details Maps CK's loop scheduler identifiers onto the builder framework's
/// standardized pipeline scheduler enum. DEFAULT uses the standard scheduling
/// strategy, while INTERWAVE enables cross-wavefront coordination.
template <ck::LoopScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
    using enum ck::LoopScheduler;
    using enum builder::PipelineScheduler;
    if constexpr(ck_sched == Default)
        return DEFAULT;
    else if constexpr(ck_sched == Interwave)
        return INTERWAVE;
}
// Helper metafunctions to derive signature information from Instance types
/// @brief Helper function to report unsupported convolution direction with a clear error message.
/// @details Throwing inside a consteval function can never be a constant expression, so
/// instantiating this function forces a compile-time error whose diagnostic contains the
/// message text below. The Instance template parameter keeps the offending kernel type
/// visible in the compiler's error output.
template <typename Instance>
[[noreturn]] consteval void report_unsupported_conv_direction_error()
{
    throw "Unsupported convolution direction detected!\n"
          "The kernel instance does not have a recognized convolution specialization.\n"
          "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
          "kConvBwdWeightSpecialization.\n"
          "Please verify that your kernel instance is properly configured.";
}
/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
/// @details The direction is inferred from which direction-specific specialization
/// constant the instance's InstanceTraits exposes; an instance exposing none of them
/// triggers a compile-time error via report_unsupported_conv_direction_error().
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
    using InstTraits = InstanceTraits<Instance>;
    // Probe for the direction-specific member; taking the address keeps the
    // requires-expression a pure existence check.
    if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
        return builder::ConvDirection::FORWARD;
    else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
        return builder::ConvDirection::BACKWARD_DATA;
    else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
        return builder::ConvDirection::BACKWARD_WEIGHT;
    else
    {
        report_unsupported_conv_direction_error<Instance>();
        return builder::ConvDirection::FORWARD; // Unreachable
    }
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvSpecialization` enum value.
/// @details The direction-specific CK specialization enum (forward, backward-data, or
/// backward-weight) is translated into the shared builder enum. If the instance exposes
/// none of the specialization constants, no branch is taken and the deduced return type
/// becomes void, which fails at the call site.
template <typename Instance>
constexpr auto conv_spec()
{
    using InstTraits = InstanceTraits<Instance>;
    using enum builder::ConvSpecialization;
    if constexpr(requires { InstTraits::kConvForwardSpecialization; })
    {
        // Forward convolution specializations.
        using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
        switch(InstTraits::kConvForwardSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Pad0: return FILTER_1X1_PAD0;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        case Filter3x3: return FILTER_3x3;
        case OddC: return ODD_C;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
    {
        // Backward-data convolution specializations (a smaller set than forward).
        using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
        switch(InstTraits::kConvBwdDataSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
    {
        // Backward-weight convolution specializations.
        using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
        switch(InstTraits::kConvBwdWeightSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        case Filter1x1Pad0: return FILTER_1X1_PAD0;
        case OddC: return ODD_C;
        }
    }
}
// Helper variable template: true exactly when the (A, B, E) layout triple matches
// the (ExpectedA, ExpectedB, ExpectedE) triple.
template <typename A,
          typename B,
          typename E,
          typename ExpectedA,
          typename ExpectedB,
          typename ExpectedE>
inline constexpr bool layouts_are =
    std::is_same_v<A, ExpectedA>
        ? (std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>)
        : false;
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function is designed to fail at compile time with a descriptive
/// error message when an unsupported layout combination is encountered. The A/B/E and
/// SpatialDim template parameters keep the offending types and dimension visible in the
/// compiler's diagnostic output.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
    // This will produce a compile-time error with the exception message
    throw "Unsupported convolution layout combination detected!\n"
          "The combination of ALayout, BLayout, and ELayout template parameters\n"
          "is not recognized for the given spatial dimension.\n"
          "Please verify that your convolution instance uses a supported layout configuration.\n"
          "Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
///         index 0 -> Input layout
///         index 1 -> Weight layout
///         index 2 -> Output layout
/// @details Each spatial dimension (1D/2D/3D) supports a fixed set of
/// (ALayout, BLayout, ELayout) combinations. Blocked spellings such as G_NW_C /
/// G_K_X_C are reported with their canonical enum values (GNWC, GKXC, ...).
/// Any unrecognized combination triggers a compile-time error.
template <typename Instance>
constexpr auto conv_layout()
    requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
    // Helper lambda to construct layout array
    auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
    using A = typename InstanceTraits<Instance>::ALayout;
    using B = typename InstanceTraits<Instance>::BLayout;
    using E = typename InstanceTraits<Instance>::ELayout;
    namespace ctl = ck::tensor_layout::convolution;
    using enum builder::TensorLayout;
    switch(InstanceTraits<Instance>::kSpatialDim)
    {
    case 1: // 1D spatial convolutions
        if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
            return layouts(GNWC, GKXC, GNWK);
        if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
            return layouts(GNWC, GKXC, GNWK);
        if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
            return layouts(NWGC, GKXC, NWGK);
        if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
            return layouts(NGCW, GKXC, NGKW);
        if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
            return layouts(NGCW, GKCX, NGKW);
        break;
    case 2: // 2D spatial convolutions
        if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
            return layouts(GNHWC, GKYXC, GNHWK);
        if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
            return layouts(GNHWC, GKYXC, GNHWK);
        if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
            return layouts(NHWGC, GKYXC, NHWGK);
        // NOTE(review): the KYXGC weight layout below is reported as GKYXC —
        // confirm this collapsing to the canonical enum is intentional.
        if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
            return layouts(NHWGC, GKYXC, NHWGK);
        if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
            return layouts(NGCHW, GKYXC, NGKHW);
        if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
            return layouts(NGCHW, GKCYX, NGKHW);
        break;
    case 3: // 3D spatial convolutions
        if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
            return layouts(GNDHWC, GKZYXC, GNDHWK);
        if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
            return layouts(GNDHWC, GKZYXC, GNDHWK);
        if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
            return layouts(NDHWGC, GKZYXC, NDHWGK);
        if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
            return layouts(NGCDHW, GKZYXC, NGKDHW);
        if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
            return layouts(NGCDHW, GKCZYX, NGKDHW);
        break;
    }
    // If we reach here, the layout combination is not supported
    // Call consteval function to trigger a compile-time error with a clear message
    report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
    // This return is unreachable but needed to satisfy the compiler
    return layouts(GNHWC, GKYXC, GNHWK);
}
/// @brief Helper function to report unsupported data type with a clear error message.
/// @details Instantiating this consteval function forces a compile-time error; the
/// ADataType parameter keeps the offending type visible in the diagnostic.
template <typename ADataType>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
    throw "Unsupported data type detected!\n"
          "The ADataType is not recognized.\n"
          "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
          "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
          "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
          "(BF8), "
          "int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
          "Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
/// @details The A-tensor data type is treated as representative of the instance.
/// Tuple types map to the paired enum values (e.g., FP16_FP16), and both the fnuz
/// and ocp BF8 variants collapse to the single BF8 enum value. Unrecognized types
/// trigger a compile-time error.
template <typename Instance>
constexpr builder::DataType conv_data_type()
    requires HasDataTypes<InstanceTraits<Instance>>
{
    using InstTraits = InstanceTraits<Instance>;
    using ADataType  = typename InstTraits::ADataType;
    using enum builder::DataType;
    if constexpr(std::is_same_v<ADataType, ck::half_t>)
        return FP16;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
        return FP16_FP16;
    else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
        return BF16;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
        return BF16_BF16;
    else if constexpr(std::is_same_v<ADataType, float>)
        return FP32;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
        return FP32_FP32;
    else if constexpr(std::is_same_v<ADataType, double>)
        return FP64;
    else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
        return FP8;
    else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
        return BF8;
    else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
        return BF8;
    else if constexpr(std::is_same_v<ADataType, int8_t>)
        return I8;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
        return I8_I8;
    else if constexpr(std::is_same_v<ADataType, uint8_t>)
        return U8;
    else
    {
        report_unsupported_data_type_error<ADataType>();
        return FP32; // Unreachable
    }
}
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
/// @details Instantiating this consteval function forces a compile-time error; the
/// ElementwiseOp parameter keeps the offending functor type visible in the diagnostic.
/// The list below mirrors the operations actually handled by elementwise_op(); it
/// previously omitted the unary activation operations (Logistic through Silu), which
/// made the diagnostic misleading.
template <typename ElementwiseOp>
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
{
    throw "Unsupported elementwise operation detected!\n"
          "The elementwise operation type is not recognized.\n"
          "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
          "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
          "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
          "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
          "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
          "UnaryConvert, Logistic, ClippedRelu, Swish, Elu, Power, LeakyRelu, UnaryAbs, Relu, "
          "SoftRelu, Sigmoid, TanH, Gelu, Silu.\n"
          "Please verify that your kernel instance uses a supported elementwise operation.";
}
/// @brief Derives the elementwise operation from op type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
/// @details Matching is performed on the functor's demangled name (via
/// detail::elementwise_op_name) using case-insensitive string comparison, so the
/// mapping is name-based rather than type-based. Unrecognized names trigger a
/// compile-time error.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    using enum builder::ElementwiseOperation;
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
    if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
        return ADD_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
        return ADD_RELU_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
        return BILINEAR;
    // NOTE(review): BiasNormalizeInInferClamp maps to the same enum value as
    // BiasBnormClamp — confirm this aliasing is intentional.
    else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
        return CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
        return CONV_INVSCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
        return CONV_SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
        return CONV_SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
        return CONV_SCALE_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
        return SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
        return SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
        return PASS_THROUGH;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
        return SCALEADD_SCALEADD_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
        return DYNAMIC_UNARY_OP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
        return UNARY_COMBINED_OP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
        return ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
        return ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
        return ADD_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
        return ADD_ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
        return ADD_MUL_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
        return ADD_MUL2_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
        return UNARY_CONVERT;
    // Unary activation operations.
    else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
        return LOGISTIC;
    else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
        return CLIPPED_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Swish"))
        return SWISH;
    else if constexpr(detail::case_insensitive_equal(name, "Elu"))
        return ELU;
    else if constexpr(detail::case_insensitive_equal(name, "Power"))
        return POWER;
    else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
        return LEAKY_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
        return UNARY_ABS;
    else if constexpr(detail::case_insensitive_equal(name, "Relu"))
        return RELU;
    else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
        return SOFT_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
        return SIGMOID;
    else if constexpr(detail::case_insensitive_equal(name, "TanH"))
        return TANH;
    else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
        return GELU;
    else if constexpr(detail::case_insensitive_equal(name, "Silu"))
        return SILU;
    else
    {
        report_unsupported_elementwise_op_error<ElementwiseOp>();
        return PASS_THROUGH; // Unreachable
    }
}
/// @brief Derives a gemm padding from a kernel instance type.
/// @tparam Instance - A Device Kernel object type.
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
/// @details Translates the CK GemmSpecialization of the instance into the builder
/// GemmPadding enum; the mapping is one-to-one over all sixteen M/N/K/O padding
/// combinations.
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
    requires HasGemmSpec<InstanceTraits<Instance>>
{
    using InstTraits = InstanceTraits<Instance>;
    using enum builder::GemmPadding;
    using enum ck::tensor_operation::device::GemmSpecialization;
    // Local constant intentionally shadows the function name within this scope.
    constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
    switch(gemm_spec)
    {
    case Default: return DEFAULT;
    case MPadding: return M_PADDING;
    case NPadding: return N_PADDING;
    case KPadding: return K_PADDING;
    case MNPadding: return MN_PADDING;
    case MKPadding: return MK_PADDING;
    case NKPadding: return NK_PADDING;
    case MNKPadding: return MNK_PADDING;
    case OPadding: return O_PADDING;
    case MOPadding: return MO_PADDING;
    case NOPadding: return NO_PADDING;
    case KOPadding: return KO_PADDING;
    case MNOPadding: return MNO_PADDING;
    case MKOPadding: return MKO_PADDING;
    case NKOPadding: return NKO_PADDING;
    case MNKOPadding: return MNKO_PADDING;
    }
}
/// @brief Primary template for extracting convolution traits.
/// @details This struct is the main entry point for reflecting on a convolution
/// kernel's properties. Only the declaration appears here; specializations handle
/// the different kinds of supported instance types.
template <typename T>
struct ConvTraits;
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <HasInstanceTraits Instance>
requires IsXdlFwdConv<InstanceTraits<Instance>>
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;
// --- Signature Information ---
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
static constexpr int spatial_dim = InstTraits::kSpatialDim;
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
static constexpr auto layout = conv_layout<Instance>();
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
static constexpr builder::DataType data_type = conv_data_type<Instance>();
int spatial_dim;
builder::ConvDirection direction;
std::array<builder::TensorLayout, 3> layout; // [input, weight, output]
builder::DataType data_type;
static constexpr builder::ElementwiseOperation input_element_op =
elementwise_op<typename InstTraits::AElementwiseOperation>();
static constexpr builder::ElementwiseOperation weight_element_op =
elementwise_op<typename InstTraits::BElementwiseOperation>();
static constexpr builder::ElementwiseOperation output_element_op =
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
/// @brief The GEMM specialization used by the kernel - padding
static constexpr auto gemm_padding = gemm_spec<Instance>();
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
static constexpr auto conv_specialization = conv_spec<Instance>();
builder::GemmPadding gemm_padding;
builder::ConvSpecialization conv_specialization;
// --- Algorithm Information ---
/// @brief The total number of threads in a thread block (workgroup).
static constexpr int thread_block_size = InstTraits::kBlockSize;
/// @brief The dimensions of the data tile processed by the thread block.
static constexpr DataTileInfo tile_dims = {
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
int thread_block_size;
DataTileInfo tile_dims;
/// @brief Configuration for the A-matrix (input) tile transfer.
static constexpr InputTileTransferInfo a_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
/// @brief Configuration for the B-matrix (weights) tile transfer.
static constexpr InputTileTransferInfo b_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
WarpGemmParams warp_gemm;
/// @brief Parameters for the warp-level GEMM computation.
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
OutputTileTransferInfo c_tile_transfer;
/// @brief Configuration for the C-matrix (output) tile transfer.
static constexpr OutputTileTransferInfo c_tile_transfer = {
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
/// @brief Helper to safely get the pipeline version.
/// @details This is only available for some convolutions (e.g., forward).
/// If not present in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_version()
{
if constexpr(requires { T::kPipelineVersion; })
{
return convert_pipeline_version<T::kPipelineVersion>();
}
else
{
// Return a default or indicate not available
return builder::PipelineVersion::V1;
}
}
/// @brief The block GEMM pipeline version used by the kernel.
static constexpr auto pipeline_version = get_pipeline_version();
/// @brief Helper to safely get the pipeline scheduler.
/// @details This is only available for some convolutions. If not present
/// in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_scheduler()
{
if constexpr(requires { T::kPipelineScheduler; })
{
return convert_pipeline_scheduler<T::kPipelineScheduler>();
}
else if constexpr(requires { T::kLoopScheduler; })
{
return convert_pipeline_scheduler<T::kLoopScheduler>();
}
else
{
// Return a default or indicate not available
return builder::PipelineScheduler::DEFAULT;
}
}
/// @brief The pipeline scheduler used by the kernel.
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    // Build the flat ConvTraits value field-by-field from the instance's
    // compile-time configuration. Designated initializers must appear in
    // ConvTraits member declaration order.
    return ConvTraits{
        // Signature information: problem shape, direction, layouts, data
        // types, and the per-tensor elementwise operations.
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = conv_layout<Instance>(),
        .data_type = conv_data_type<Instance>(),
        .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
        // Kernel specializations: GEMM padding mode and the convolution
        // variant (e.g. Default, 1x1).
        .gemm_padding = gemm_spec<Instance>(),
        .conv_specialization = conv_spec<Instance>(),
        // Block-level algorithm configuration.
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = {.m = InstTraits::kMPerBlock,
                      .n = InstTraits::kNPerBlock,
                      .k = InstTraits::kKPerBlock},
        // A-matrix (input) tile transfer. k0 is derived so that
        // k0 * k1 == kKPerBlock.
        .a_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
                                 .m_or_n = InstTraits::kMPerBlock,
                                 .k1 = InstTraits::kAK1},
             .transfer_params = {.k1 = InstTraits::kAK1,
                                 .thread_cluster_dims = InstTraits::kAThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kABlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kABlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
        // B-matrix (weight) tile transfer, mirroring the A-side parameters.
        .b_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
                                 .m_or_n = InstTraits::kNPerBlock,
                                 .k1 = InstTraits::kBK1},
             .transfer_params = {.k1 = InstTraits::kBK1,
                                 .thread_cluster_dims = InstTraits::kBThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kBBlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kBBlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
        // Warp-level XDL GEMM shape and per-wave iteration counts.
        .warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
                      .gemm_n = InstTraits::kNPerXDL,
                      .m_iter = InstTraits::kMXdlPerWave,
                      .n_iter = InstTraits::kNXdlPerWave},
        // C-matrix (output) shuffle/store configuration.
        .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
                                                   InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                               .n_gemms_per_shuffle =
                                                   InstTraits::kCShuffleNXdlPerWavePerShuffle},
                            .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                                    InstTraits::kCThreadClusterLengths[1],
                                                    InstTraits::kCThreadClusterLengths[2],
                                                    InstTraits::kCThreadClusterLengths[3]},
                            .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
        // Pipeline configuration; the helpers fall back to defaults when the
        // instance traits do not define the corresponding fields.
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    // Build the flat ConvTraits value field-by-field from the instance's
    // compile-time configuration. Designated initializers must appear in
    // ConvTraits member declaration order.
    return ConvTraits{
        // Signature information: problem shape, direction, layouts, data
        // types, and the per-tensor elementwise operations.
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = conv_layout<Instance>(),
        .data_type = conv_data_type<Instance>(),
        .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
        // Kernel specializations: GEMM padding mode and the convolution
        // variant (e.g. Default, 1x1).
        .gemm_padding = gemm_spec<Instance>(),
        .conv_specialization = conv_spec<Instance>(),
        // Block-level algorithm configuration.
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = {.m = InstTraits::kMPerBlock,
                      .n = InstTraits::kNPerBlock,
                      .k = InstTraits::kKPerBlock},
        // A-matrix (input) tile transfer. k0 is derived so that
        // k0 * k1 == kKPerBlock.
        .a_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
                                 .m_or_n = InstTraits::kMPerBlock,
                                 .k1 = InstTraits::kAK1},
             .transfer_params = {.k1 = InstTraits::kAK1,
                                 .thread_cluster_dims = InstTraits::kAThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kABlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kABlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
        // B-matrix (weight) tile transfer, mirroring the A-side parameters.
        .b_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
                                 .m_or_n = InstTraits::kNPerBlock,
                                 .k1 = InstTraits::kBK1},
             .transfer_params = {.k1 = InstTraits::kBK1,
                                 .thread_cluster_dims = InstTraits::kBThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kBBlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kBBlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
        // Warp-level XDL GEMM shape and per-wave iteration counts.
        .warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
                      .gemm_n = InstTraits::kNPerXDL,
                      .m_iter = InstTraits::kMXdlPerWave,
                      .n_iter = InstTraits::kNXdlPerWave},
        // C-matrix (output) shuffle/store configuration.
        .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
                                                   InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                               .n_gemms_per_shuffle =
                                                   InstTraits::kCShuffleNXdlPerWavePerShuffle},
                            .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                                    InstTraits::kCThreadClusterLengths[1],
                                                    InstTraits::kCThreadClusterLengths[2],
                                                    InstTraits::kCThreadClusterLengths[3]},
                            .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
        // Pipeline configuration; the helpers fall back to defaults when the
        // instance traits do not define the corresponding fields.
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
template <typename Instance>
    requires HasInstanceTraits<Instance> &&
             std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                          DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    // Build the flat ConvTraits value field-by-field from the instance's
    // compile-time configuration. Designated initializers must appear in
    // ConvTraits member declaration order.
    return ConvTraits{
        // Signature information: problem shape, direction, layouts, data
        // types, and the per-tensor elementwise operations.
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        .layout = conv_layout<Instance>(),
        .data_type = conv_data_type<Instance>(),
        .input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
        // Kernel specializations: GEMM padding mode and the convolution
        // variant (e.g. Default, 1x1).
        .gemm_padding = gemm_spec<Instance>(),
        .conv_specialization = conv_spec<Instance>(),
        // Block-level algorithm configuration.
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = {.m = InstTraits::kMPerBlock,
                      .n = InstTraits::kNPerBlock,
                      .k = InstTraits::kKPerBlock},
        // A-matrix (input) tile transfer. k0 is derived so that
        // k0 * k1 == kKPerBlock.
        .a_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
                                 .m_or_n = InstTraits::kMPerBlock,
                                 .k1 = InstTraits::kAK1},
             .transfer_params = {.k1 = InstTraits::kAK1,
                                 .thread_cluster_dims = InstTraits::kAThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kABlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kABlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
        // B-matrix (weight) tile transfer, mirroring the A-side parameters.
        .b_tile_transfer =
            {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
                                 .m_or_n = InstTraits::kNPerBlock,
                                 .k1 = InstTraits::kBK1},
             .transfer_params = {.k1 = InstTraits::kBK1,
                                 .thread_cluster_dims = InstTraits::kBThreadClusterLengths,
                                 .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
                                 .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
                                 .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
                                 .src_scalar_per_vector =
                                     InstTraits::kBBlockTransferSrcScalarPerVector,
                                 .dst_scalar_per_vector_k1 =
                                     InstTraits::kBBlockTransferDstScalarPerVectorK1,
                                 .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
        // Warp-level XDL GEMM shape and per-wave iteration counts.
        .warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
                      .gemm_n = InstTraits::kNPerXDL,
                      .m_iter = InstTraits::kMXdlPerWave,
                      .n_iter = InstTraits::kNXdlPerWave},
        // C-matrix (output) shuffle/store configuration.
        .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
                                                   InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                               .n_gemms_per_shuffle =
                                                   InstTraits::kCShuffleNXdlPerWavePerShuffle},
                            .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                                    InstTraits::kCThreadClusterLengths[1],
                                                    InstTraits::kCThreadClusterLengths[2],
                                                    InstTraits::kCThreadClusterLengths[3]},
                            .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
        // Pipeline configuration; the helpers fall back to defaults when the
        // instance traits do not define the corresponding fields.
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,739 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <concepts>
#include <string_view>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/reflect/conv_types.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
/// @file conv_traits_helpers.hpp
/// @brief Helper utilities for extracting convolution traits from kernel instances
///
/// This file provides compile-time reflection utilities to extract configuration
/// information from CK convolution kernel instances and convert them to the builder
/// framework's standardized representation.
///
/// ## Organization
///
/// The file is organized into the following sections:
///
/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums
/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion)
/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler)
///
/// 2. **Signature Derivation**: Functions to extract signature information from instances
/// - Convolution direction (conv_direction)
/// - Convolution specialization (conv_spec)
/// - Tensor layouts (conv_layout)
/// - Data types (conv_data_type)
/// - Elementwise operations (elementwise_op)
/// - GEMM padding (gemm_spec)
///
/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters
/// - Pipeline version extraction (get_pipeline_version)
/// - Pipeline scheduler extraction (get_pipeline_scheduler)
///
/// ## Error Handling Strategy
///
/// This file uses a specific error handling pattern for compile-time errors:
/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't
/// silently ignore errors. The thrown string becomes part of the compiler error message,
/// providing clear context to developers.
/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE,
/// which would hide errors instead of reporting them clearly.
///
/// @example
/// ```cpp
/// using Instance =
/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>;
///
/// // Extract convolution direction
/// constexpr auto dir = conv_direction<Instance>();
///
/// // Extract data type
/// constexpr auto dtype = conv_data_type<Instance>();
///
/// // Extract layout configuration
/// constexpr auto layouts = conv_layout<Instance>();
/// ```
namespace ck_tile::reflect::conv {
// ============================================================================
// SECTION 1: ENUM CONVERSIONS
// ============================================================================
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value.
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
///
/// Supported mappings:
/// - v1 -> V1
/// - v2 -> V2
/// - v3 -> V3
/// - v4 -> V4
/// - v5 -> V5
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr builder::PipelineVersion convert_pipeline_version()
{
    // The version is a template parameter, so the whole mapping resolves at
    // compile time; each CK block-GEMM version has a same-numbered builder
    // counterpart.
    using enum ck::BlockGemmPipelineVersion;
    using enum builder::PipelineVersion;
    if constexpr(ck_ver == v1)
        return V1;
    else if constexpr(ck_ver == v2)
        return V2;
    else if constexpr(ck_ver == v3)
        return V3;
    else if constexpr(ck_ver == v4)
        return V4;
    else
    {
        // v5 is the only remaining enumerator; reject anything else at
        // compile time.
        static_assert(ck_ver == v5, "unhandled BlockGemmPipelineVersion");
        return V5;
    }
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value.
/// @details This function maps CK's general pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. Note that this overload
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
/// variant, including support for specialized weight-only pipelines.
///
/// Supported mappings:
/// - v1 -> V1
/// - v2 -> V2
/// - v4 -> V4
/// - weight_only -> WEIGHT_ONLY
template <ck::PipelineVersion ck_ver>
constexpr builder::PipelineVersion convert_pipeline_version()
{
    // Compile-time mapping for the general pipeline-version enum; note this
    // enum has no v3/v5 but adds a specialized weight-only pipeline.
    using enum ck::PipelineVersion;
    using enum builder::PipelineVersion;
    if constexpr(ck_ver == v1)
        return V1;
    else if constexpr(ck_ver == v2)
        return V2;
    else if constexpr(ck_ver == v4)
        return V4;
    else
    {
        // weight_only is the only remaining enumerator; reject anything else
        // at compile time.
        static_assert(ck_ver == weight_only, "unhandled PipelineVersion");
        return WEIGHT_ONLY;
    }
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value.
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
/// builder framework's standardized scheduler enum. The scheduler determines how work
/// is distributed and synchronized within and across wavefronts during pipeline execution.
///
/// Supported mappings:
/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront
/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
{
    // Two-way compile-time mapping: intra-wavefront vs cross-wavefront
    // scheduling.
    using enum ck::BlockGemmPipelineScheduler;
    using enum builder::PipelineScheduler;
    if constexpr(ck_sched == Intrawave)
        return INTRAWAVE;
    else
    {
        static_assert(ck_sched == Interwave, "unhandled BlockGemmPipelineScheduler");
        return INTERWAVE;
    }
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value.
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
/// the main computational loop are scheduled across threads.
///
/// Supported mappings:
/// - Default -> DEFAULT: Standard scheduling strategy
/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance
template <ck::LoopScheduler ck_sched>
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
{
    // Two-way compile-time mapping from the loop-scheduler enum: standard
    // scheduling vs cross-wavefront coordination.
    using enum ck::LoopScheduler;
    using enum builder::PipelineScheduler;
    if constexpr(ck_sched == Default)
        return DEFAULT;
    else
    {
        static_assert(ck_sched == Interwave, "unhandled LoopScheduler");
        return INTERWAVE;
    }
}
// ============================================================================
// SECTION 2: SIGNATURE DERIVATION FUNCTIONS
// ============================================================================
// ----------------------------------------------------------------------------
// Convolution Direction
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported convolution direction with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename Instance>
[[noreturn]] consteval void report_unsupported_conv_direction_error()
{
    // consteval + throw (rather than static_assert) so the failure is not
    // silently swallowed in SFINAE contexts; the string literal below shows
    // up verbatim in the compiler diagnostic. The Instance parameter exists
    // only so the offending type appears in the instantiation backtrace.
    throw "Unsupported convolution direction detected!\n"
          "The kernel instance does not have a recognized convolution specialization.\n"
          "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
          "kConvBwdWeightSpecialization.\n"
          "Please verify that your kernel instance is properly configured.";
}
/// @brief Derives the convolution direction from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
/// @details This function inspects the Instance's InstanceTraits to determine which
/// convolution specialization field is present, and returns the corresponding direction.
///
/// The function checks for the presence of:
/// - kConvForwardSpecialization -> FORWARD
/// - kConvBwdDataSpecialization -> BACKWARD_DATA
/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT
///
/// @note Compilation will fail with a clear error message if the instance does not
/// have a recognized convolution specialization field.
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
    using InstTraits = InstanceTraits<Instance>;
    // Probe for the direction-specific specialization constant using the same
    // value-expression form as conv_spec(), so both helpers detect the trait
    // members identically. (The previous address-of form `&InstTraits::k...`
    // additionally required the member to be addressable, which is
    // unnecessary for presence detection.)
    if constexpr(requires { InstTraits::kConvForwardSpecialization; })
        return builder::ConvDirection::FORWARD;
    else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
        return builder::ConvDirection::BACKWARD_DATA;
    else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
        return builder::ConvDirection::BACKWARD_WEIGHT;
    else
    {
        // No recognized specialization field: produce a readable compile-time
        // error via the consteval helper.
        report_unsupported_conv_direction_error<Instance>();
        return builder::ConvDirection::FORWARD; // Unreachable
    }
}
// ----------------------------------------------------------------------------
// Convolution Specialization
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported convolution specialization with a clear error
/// message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename Instance>
[[noreturn]] consteval void report_unsupported_conv_spec_error()
{
    // consteval + throw (rather than static_assert) so the failure is not
    // silently swallowed in SFINAE contexts; the string literal below shows
    // up verbatim in the compiler diagnostic. The Instance parameter exists
    // only so the offending type appears in the instantiation backtrace.
    throw "Unsupported convolution specialization detected!\n"
          "The kernel instance does not have a recognized convolution specialization field.\n"
          "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
          "kConvBwdWeightSpecialization.\n"
          "Please verify that your kernel instance is properly configured.";
}
/// @brief Derives the convolution-specific specialization from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return A builder::ConvSpecialization enum value.
/// @details This function extracts the specialization enum from the Instance's InstanceTraits
/// and converts it to the corresponding builder framework enum.
///
/// For forward convolutions, supported specializations include:
/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC
///
/// For backward data convolutions:
/// - Default, Filter1x1Stride1Pad0
///
/// For backward weight convolutions:
/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC
template <typename Instance>
constexpr builder::ConvSpecialization conv_spec()
{
    using InstTraits = InstanceTraits<Instance>;
    // Which specialization constant the traits expose determines the
    // direction family; within each family the CK enumerators map 1:1 onto
    // builder enumerators via an exhaustive switch.
    if constexpr(requires { InstTraits::kConvForwardSpecialization; })
    {
        // Forward convolutions support the widest set of specializations.
        using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
        using enum builder::ConvSpecialization;
        switch(InstTraits::kConvForwardSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Pad0: return FILTER_1X1_PAD0;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        case Filter3x3: return FILTER_3x3;
        case OddC: return ODD_C;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
    {
        // Backward-data kernels define only two specializations.
        using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
        using enum builder::ConvSpecialization;
        switch(InstTraits::kConvBwdDataSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
    {
        // Backward-weight kernels: four specializations.
        using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
        using enum builder::ConvSpecialization;
        switch(InstTraits::kConvBwdWeightSpecialization)
        {
        case Default: return DEFAULT;
        case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
        case Filter1x1Pad0: return FILTER_1X1_PAD0;
        case OddC: return ODD_C;
        }
    }
    else
    {
        // No recognized specialization field: produce a readable compile-time
        // error via the consteval helper.
        report_unsupported_conv_spec_error<Instance>();
        return builder::ConvSpecialization::DEFAULT; // Unreachable
    }
}
// ----------------------------------------------------------------------------
// Tensor Layouts
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
    // consteval + throw (rather than static_assert) so the failure is not
    // silently swallowed in SFINAE contexts; the string literal below shows
    // up verbatim in the compiler diagnostic. The template parameters are
    // deliberately unused in the body: they exist so the offending layout
    // types and spatial dimension appear in the instantiation backtrace.
    throw "Unsupported convolution layout combination detected!\n"
          "The combination of ALayout, BLayout, and ELayout template parameters\n"
          "is not recognized for the given spatial dimension.\n"
          "Please verify that your convolution instance uses a supported layout configuration.\n"
          "Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
/// - [0] Input tensor layout
/// - [1] Weight tensor layout
/// - [2] Output tensor layout
/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
/// along with the spatial dimension to determine the appropriate layout configuration.
///
/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
///
/// @note Compilation will fail with a clear error message if the layout combination
/// is not supported for the given spatial dimension.
///
/// TODO: If we don't check for supported layouts, this function can be simplified.
template <typename Instance>
constexpr std::array<builder::TensorLayout, 3> conv_layout()
{
    using InstTraits = InstanceTraits<Instance>;
    using A = typename InstTraits::ALayout;
    using B = typename InstTraits::BLayout;
    using E = typename InstTraits::ELayout;
    namespace ctl = ck::tensor_layout::convolution;
    using enum builder::TensorLayout;
    // Helper to check if layouts match expected types
    constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
        return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
    };
    // Helper to construct layout array: {input, weight, output}
    constexpr auto make_layouts = [](auto in, auto weight, auto out) {
        return std::array<builder::TensorLayout, 3>{in, weight, out};
    };
    constexpr int spatial_dim = InstTraits::kSpatialDim;
    // The match ladder below is per spatial dimension. Note that several
    // distinct CK layout types describing the same logical ordering (e.g.
    // GNWC vs G_NW_C) deliberately collapse onto one builder enum value.
    if constexpr(spatial_dim == 1)
    {
        if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
            return make_layouts(GNWC, GKXC, GNWK);
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
            return make_layouts(GNWC, GKXC, GNWK);
        else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
            return make_layouts(NWGC, GKXC, NWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
            return make_layouts(NGCW, GKXC, NGKW);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
            return make_layouts(NGCW, GKCX, NGKW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNWC, GKXC, GNWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 2)
    {
        if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        // NOTE(review): CK weight layout KYXGC is reported as builder GKYXC
        // here — presumably the builder enum has no KYXGC value and this is an
        // intentional collapse, but confirm callers do not need to
        // distinguish the two orderings.
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
            return make_layouts(NGCHW, GKYXC, NGKHW);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
            return make_layouts(NGCHW, GKCYX, NGKHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 3)
    {
        if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        else if constexpr(layouts_match
                              .template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
            return make_layouts(NDHWGC, GKZYXC, NDHWGK);
        else if constexpr(layouts_match
                              .template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKZYXC, NGKDHW);
        else if constexpr(layouts_match
                              .template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKCZYX, NGKDHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
        }
    }
    else
    {
        // Spatial dimensions other than 1/2/3 are not supported.
        report_unsupported_layout_error<A, B, E, spatial_dim>();
        return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
    }
}
// ----------------------------------------------------------------------------
// Data Types
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported data type with a clear error message.
/// @tparam ADataType The unrecognized input data type (appears in the instantiation trace).
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ADataType>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
    // Keep this list in sync with the branches in conv_data_type().
    throw "Unsupported data type detected!\n"
          "The ADataType is not recognized.\n"
          "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
          "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
          "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
          "(BF8), ck::bf8_ocp_t (BF8), "
          "int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
          "Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return A builder::DataType enum value representing the input data type.
/// @details This function examines the Instance's ADataType (via InstanceTraits) to determine
/// the data type used for the input tensor. Tuple types encode multi-input variants
/// (e.g. FP16_FP16) used for mixed-precision operations.
///
/// Supported data types include:
/// - FP16 (ck::half_t)
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
/// - BF16 (ck::bhalf_t)
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
/// - FP32 (float)
/// - FP32_FP32 (ck::Tuple<float, float>)
/// - FP64 (double)
/// - FP8 (ck::f8_t)
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
/// - I8 (int8_t)
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
/// - U8 (uint8_t)
template <typename Instance>
constexpr builder::DataType conv_data_type()
{
    using InstTraits = InstanceTraits<Instance>;
    using ADataType  = typename InstTraits::ADataType;
    using enum builder::DataType;
    if constexpr(std::is_same_v<ADataType, ck::half_t>)
        return FP16;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
        return FP16_FP16;
    else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
        return BF16;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
        return BF16_BF16;
    else if constexpr(std::is_same_v<ADataType, float>)
        return FP32;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
        return FP32_FP32;
    else if constexpr(std::is_same_v<ADataType, double>)
        return FP64;
    else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
        return FP8;
    // Both the FNUZ and OCP 8-bit bfloat encodings map to the same builder enum,
    // so they share a single branch rather than two duplicated ones.
    else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t> ||
                      std::is_same_v<ADataType, ck::bf8_ocp_t>)
        return BF8;
    else if constexpr(std::is_same_v<ADataType, int8_t>)
        return I8;
    else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
        return I8_I8;
    else if constexpr(std::is_same_v<ADataType, uint8_t>)
        return U8;
    else
    {
        report_unsupported_data_type_error<ADataType>();
        return FP32; // Unreachable
    }
}
// ----------------------------------------------------------------------------
// Elementwise Operations
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
/// @tparam ElementwiseOp The unrecognized operation type (appears in the instantiation trace).
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ElementwiseOp>
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
{
    // Keep this list in sync with the branches in elementwise_op().
    throw "Unsupported elementwise operation detected!\n"
          "The elementwise operation type is not recognized.\n"
          "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
          "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
          "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
          "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
          "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
          "UnaryConvert, Logistic, ClippedRelu, Swish, Elu, Power, LeakyRelu, UnaryAbs, Relu, "
          "SoftRelu, Sigmoid, TanH, Gelu, Silu.\n"
          "Please verify that your kernel instance uses a supported elementwise operation.";
}
/// @brief Derives the elementwise operation from an operation functor type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A builder::ElementwiseOperation enum value corresponding to the operation.
/// @details This function uses the operation's type name to determine which elementwise
/// operation is being used. The comparison is case-insensitive, so each branch matches the
/// operation name exactly (not as a prefix) regardless of casing.
///
/// Supported operations include:
/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc.
/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc.
/// - Clamping operations: Clamp, AddClamp, etc.
/// - Combined operations: Add_Activation_Mul_Clamp, etc.
/// - Utility operations: PassThrough, UnaryConvert, etc.
///
/// Unrecognized names trigger report_unsupported_elementwise_op_error, which produces a
/// compile-time diagnostic.
///
/// TODO: Consider changing this to direct checks on the types, not strings.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    using enum builder::ElementwiseOperation;
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
    if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
        return ADD_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
        return ADD_RELU_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
        return BILINEAR;
    // NOTE(review): BiasNormalizeInInferClamp maps to the same enum value as
    // BiasBnormClamp — confirm this aliasing is intended and not a copy-paste slip.
    else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
        return CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
        return CONV_INVSCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
        return CONV_SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
        return CONV_SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
        return CONV_SCALE_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
        return SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
        return SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
        return PASS_THROUGH;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
        return SCALEADD_SCALEADD_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
        return DYNAMIC_UNARY_OP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
        return UNARY_COMBINED_OP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
        return ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
        return ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
        return ADD_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
        return ADD_ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
        return ADD_MUL_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
        return ADD_MUL2_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
        return UNARY_CONVERT;
    else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
        return LOGISTIC;
    else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
        return CLIPPED_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Swish"))
        return SWISH;
    else if constexpr(detail::case_insensitive_equal(name, "Elu"))
        return ELU;
    else if constexpr(detail::case_insensitive_equal(name, "Power"))
        return POWER;
    else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
        return LEAKY_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
        return UNARY_ABS;
    else if constexpr(detail::case_insensitive_equal(name, "Relu"))
        return RELU;
    else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
        return SOFT_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
        return SIGMOID;
    else if constexpr(detail::case_insensitive_equal(name, "TanH"))
        return TANH;
    else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
        return GELU;
    else if constexpr(detail::case_insensitive_equal(name, "Silu"))
        return SILU;
    else
    {
        report_unsupported_elementwise_op_error<ElementwiseOp>();
        return PASS_THROUGH; // Unreachable
    }
}
// ----------------------------------------------------------------------------
// GEMM Padding
// ----------------------------------------------------------------------------
/// @brief Derives the GEMM padding specification from a kernel instance type.
/// @tparam Instance A device kernel instance type.
/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration.
/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits
/// and converts it to the builder framework's GemmPadding enum. The padding specification
/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes.
///
/// Supported padding configurations include:
/// - DEFAULT: No padding
/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding
/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding
/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding
/// - MNKO_PADDING: All dimensions padded
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
{
    using InstTraits = InstanceTraits<Instance>;
    using enum builder::GemmPadding;
    using enum ck::tensor_operation::device::GemmSpecialization;
    constexpr auto spec = InstTraits::kGemmSpecialization;
    switch(spec)
    {
    case Default: return DEFAULT;
    case MPadding: return M_PADDING;
    case NPadding: return N_PADDING;
    case KPadding: return K_PADDING;
    case MNPadding: return MN_PADDING;
    case MKPadding: return MK_PADDING;
    case NKPadding: return NK_PADDING;
    case MNKPadding: return MNK_PADDING;
    case OPadding: return O_PADDING;
    case MOPadding: return MO_PADDING;
    case NOPadding: return NO_PADDING;
    case KOPadding: return KO_PADDING;
    case MNOPadding: return MNO_PADDING;
    case MKOPadding: return MKO_PADDING;
    case NKOPadding: return NKO_PADDING;
    case MNKOPadding: return MNKO_PADDING;
    }
    // Every enumerator is handled above, so this is unreachable for valid input.
    // The fallback keeps the function from flowing off the end of a non-void
    // function (UB / -Wreturn-type) if GemmSpecialization ever gains a new value.
    return DEFAULT;
}
// ============================================================================
// SECTION 3: PIPELINE CONFIGURATION HELPERS
// ============================================================================
/// @brief Safely extracts the pipeline version from InstanceTraits.
/// @tparam InstTraits The InstanceTraits type to extract pipeline version from.
/// @return The pipeline version as a builder::PipelineVersion enum value.
/// @details Not every convolution instance type exposes a kPipelineVersion field.
/// When the field is absent this falls back to builder::PipelineVersion::V1;
/// otherwise the raw value is translated via convert_pipeline_version.
template <typename InstTraits>
constexpr builder::PipelineVersion get_pipeline_version()
{
    constexpr bool has_pipeline_version = requires { InstTraits::kPipelineVersion; };
    if constexpr(!has_pipeline_version)
        return builder::PipelineVersion::V1;
    else
        return convert_pipeline_version<InstTraits::kPipelineVersion>();
}
/// @brief Safely extracts the pipeline scheduler from InstanceTraits.
/// @tparam InstTraits The InstanceTraits type to extract pipeline scheduler from.
/// @return The pipeline scheduler as a builder::PipelineScheduler enum value.
/// @details Different convolution instance types expose scheduler information
/// through different field names. Preference order: kPipelineScheduler first,
/// then kLoopScheduler; if neither field exists the DEFAULT scheduler is returned.
template <typename InstTraits>
constexpr builder::PipelineScheduler get_pipeline_scheduler()
{
    constexpr bool has_block_gemm_sched = requires { InstTraits::kPipelineScheduler; };
    constexpr bool has_loop_sched       = requires { InstTraits::kLoopScheduler; };
    if constexpr(has_block_gemm_sched)
        return convert_pipeline_scheduler<InstTraits::kPipelineScheduler>();
    else if constexpr(has_loop_sched)
        return convert_pipeline_scheduler<InstTraits::kLoopScheduler>();
    else
        return builder::PipelineScheduler::DEFAULT;
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,8 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"

View File

@@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag
{
};
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <ck::index_t NDimSpatial,
typename ALayout_,
@@ -175,6 +180,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
LoopSched,
NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;

View File

@@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag
{
};
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <ck::index_t NDimSpatial,
typename ALayout_,
@@ -179,6 +184,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
BComputeDataType_,
DirectLoad>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;

View File

@@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag
{
};
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
template <ck::index_t NDimSpatial,
typename ALayout_,
@@ -173,6 +178,9 @@ struct InstanceTraits<
BComputeDataType_,
LoopSched>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;

View File

@@ -108,7 +108,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp
conv/ck/unit_instance_to_conv_traits.cpp)
conv/ck/unit_instance_to_conv_traits_features.cpp
conv/ck/unit_instance_to_conv_traits_instances.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description

View File

@@ -6,7 +6,7 @@
#include <concepts>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
@@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
ck::half_t, // BComputeDataType
false>; // DirectLoad
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
// Verify A tile transfer info
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE);
EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1);
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
@@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
ck::LoopScheduler::Default, // LoopSched
1>; // NumGroupsToMerge
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
@@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default>; // LoopSched
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
}
} // anonymous namespace

View File

@@ -0,0 +1,800 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// ============================================================================
// Unit Tests for Individual Conversion Functions
// ============================================================================
//
// PURPOSE:
// --------
// These tests verify individual conversion and extraction functions that
// transform raw CK kernel parameters into semantic types. Each test
// focuses on a single conversion function to ensure it correctly maps
// CK types to builder enums and structures.
//
// TEST COVERAGE:
// --------------
// 1. Enum Conversions:
// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion)
// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler)
//
// 2. Elementwise Operations (14 operations):
// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd
// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd
//
// 3. Convolution Properties:
// - Direction detection (Forward)
// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0,
// Filter3x3, OddC)
//
// 4. Layout Detection:
// - 1D layouts (GNWC, NWGC, NGCW)
// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX)
// - 3D layouts (GNDHWC, NDHWGC, NGCDHW)
//
// 5. Data Type Detection:
// - FP16, BF16, FP32, I8
//
// 6. Pipeline Configuration:
// - Pipeline versions (V2, V3)
// - Schedulers (Interwave)
//
// 7. GEMM Padding Variations (17 types):
// - Default, MNK, M, N, K, MN, MK, NK
// - O, MO, NO, KO, MNO, MKO, NKO, MNKO
// ============================================================================
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/types.hpp"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::ck_tile::builder::ConvDirection;
using ::ck_tile::builder::DataType;
using ::ck_tile::builder::ElementwiseOperation;
using ::ck_tile::builder::GemmPadding;
using ::ck_tile::builder::PipelineScheduler;
using ::ck_tile::builder::PipelineVersion;
using ::ck_tile::builder::TensorLayout;
using ::testing::ElementsAre;
// ============================================================================
// Test Helper Templates
// ============================================================================
// These templates reduce boilerplate by providing sensible defaults for
// template parameters that don't vary in most tests.
// ============================================================================
namespace defaults {
// Default values used across most tests.
// These describe a single baseline kernel configuration (256-thread block,
// 128x128x16 tile) so individual tests only override the parameter under test.
// Thread block and tile shape
static constexpr int kBlockSize = 256;
static constexpr int kMPerBlock = 128;
static constexpr int kNPerBlock = 128;
static constexpr int kKPerBlock = 16;
static constexpr int kAK1 = 8;
static constexpr int kBK1 = 8;
// XDL (MFMA) warp-level GEMM shape and per-wave iteration counts
static constexpr int kMPerXDL = 32;
static constexpr int kNPerXDL = 32;
static constexpr int kMXdlPerWave = 4;
static constexpr int kNXdlPerWave = 4;
// A-tensor block transfer parameters
static constexpr int kABlockTransferSrcVectorDim = 2;
static constexpr int kABlockTransferSrcScalarPerVector = 8;
static constexpr int kABlockTransferDstScalarPerVector_AK1 = 8;
static constexpr int kABlockLdsExtraM = 1;
// B-tensor block transfer parameters
static constexpr int kBBlockTransferSrcVectorDim = 2;
static constexpr int kBBlockTransferSrcScalarPerVector = 8;
static constexpr int kBBlockTransferDstScalarPerVector_BK1 = 8;
static constexpr int kBBlockLdsExtraN = 1;
// C-shuffle (output) transfer parameters
static constexpr int kCShuffleMXdlPerWavePerShuffle = 1;
static constexpr int kCShuffleNXdlPerWavePerShuffle = 1;
static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr bool kDirectLoad = false;
// Thread-cluster shapes and access orders shared by the A/B/CDE transfers
using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
using DefaultABlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>;
using DefaultBBlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
using DefaultBBlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>;
using DefaultCDEBlockTransferClusterLengths = ck::Sequence<1, 32, 1, 8>;
} // namespace defaults
// DeviceInstanceForTests - V3 variant with sensible defaults
template <int NDimSpatial = 2,
typename ALayout = ck::tensor_layout::convolution::GNHWC,
typename BLayout = ck::tensor_layout::convolution::GKYXC,
typename ELayout = ck::tensor_layout::convolution::GNHWK,
typename ADataType = ck::half_t,
typename BDataType = ck::half_t,
typename EDataType = ck::half_t,
typename AccDataType = float,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
ck::tensor_operation::device::GemmSpecialization GemmSpec =
ck::tensor_operation::device::GemmSpecialization::Default,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched =
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer = ck::BlockGemmPipelineVersion::v1>
using DeviceInstanceForTests_V3 =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
AccDataType,
ADataType,
ck::Tuple<>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ConvForwardSpecialization,
GemmSpec,
defaults::kBlockSize,
defaults::kMPerBlock,
defaults::kNPerBlock,
defaults::kKPerBlock,
defaults::kAK1,
defaults::kBK1,
defaults::kMPerXDL,
defaults::kNPerXDL,
defaults::kMXdlPerWave,
defaults::kNXdlPerWave,
defaults::DefaultABlockTransferThreadClusterLengths,
defaults::DefaultABlockTransferThreadClusterArrangeOrder,
defaults::DefaultABlockTransferSrcAccessOrder,
defaults::kABlockTransferSrcVectorDim,
defaults::kABlockTransferSrcScalarPerVector,
defaults::kABlockTransferDstScalarPerVector_AK1,
defaults::kABlockLdsExtraM,
defaults::DefaultBBlockTransferThreadClusterLengths,
defaults::DefaultBBlockTransferThreadClusterArrangeOrder,
defaults::DefaultBBlockTransferSrcAccessOrder,
defaults::kBBlockTransferSrcVectorDim,
defaults::kBBlockTransferSrcScalarPerVector,
defaults::kBBlockTransferDstScalarPerVector_BK1,
defaults::kBBlockLdsExtraN,
defaults::kCShuffleMXdlPerWavePerShuffle,
defaults::kCShuffleNXdlPerWavePerShuffle,
defaults::DefaultCDEBlockTransferClusterLengths,
defaults::kCDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ADataType,
BDataType,
defaults::kDirectLoad>;
// Test case helper for specialization testing
//
// Fixes a 2D fp16 GNHWC/GKYXC/GNHWK instance and varies only the
// ConvolutionForwardSpecialization parameter.
template <ck::tensor_operation::device::ConvolutionForwardSpecialization Spec>
using SpecializationTestInstance =
    DeviceInstanceForTests_V3<2,
                              ck::tensor_layout::convolution::GNHWC,
                              ck::tensor_layout::convolution::GKYXC,
                              ck::tensor_layout::convolution::GNHWK,
                              ck::half_t,
                              ck::half_t,
                              ck::half_t,
                              float,
                              ck::tensor_operation::element_wise::PassThrough,
                              ck::tensor_operation::element_wise::PassThrough,
                              ck::tensor_operation::element_wise::PassThrough,
                              Spec>;
// Test case helper for layout testing (1D, 2D, 3D)
//
// Varies only the spatial dimension and the A/B/E tensor layouts; everything
// else keeps the DeviceInstanceForTests_V3 defaults.
template <int NDim, typename ALayout, typename BLayout, typename ELayout>
using LayoutTestInstance = DeviceInstanceForTests_V3<NDim, ALayout, BLayout, ELayout>;
// Test case helper for data type testing
//
// 2D GNHWC/GKYXC/GNHWK instance that uses DataType for the A, B, and E
// tensors and AccDataType (default float) for accumulation.
template <typename DataType, typename AccDataType = float>
using DataTypeTestInstance = DeviceInstanceForTests_V3<2,
                                                       ck::tensor_layout::convolution::GNHWC,
                                                       ck::tensor_layout::convolution::GKYXC,
                                                       ck::tensor_layout::convolution::GNHWK,
                                                       DataType,
                                                       DataType,
                                                       DataType,
                                                       AccDataType>;
// Test case helper for pipeline version testing
//
// 2D fp16 default-specialization instance pinned to the Intrawave scheduler;
// varies only the BlockGemmPipelineVersion parameter.
template <ck::BlockGemmPipelineVersion PipelineVer>
using PipelineVersionTestInstance = DeviceInstanceForTests_V3<
    2,
    ck::tensor_layout::convolution::GNHWC,
    ck::tensor_layout::convolution::GKYXC,
    ck::tensor_layout::convolution::GNHWK,
    ck::half_t,
    ck::half_t,
    ck::half_t,
    float,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
    ck::tensor_operation::device::GemmSpecialization::Default,
    ck::BlockGemmPipelineScheduler::Intrawave,
    PipelineVer>;
// Test case helper for pipeline scheduler testing
//
// 2D fp16 default-specialization instance that varies only the
// BlockGemmPipelineScheduler parameter (pipeline version stays at its
// default).
template <ck::BlockGemmPipelineScheduler Scheduler>
using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3<
    2,
    ck::tensor_layout::convolution::GNHWC,
    ck::tensor_layout::convolution::GKYXC,
    ck::tensor_layout::convolution::GNHWK,
    ck::half_t,
    ck::half_t,
    ck::half_t,
    float,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
    ck::tensor_operation::device::GemmSpecialization::Default,
    Scheduler>;
// Test case helper for GEMM padding testing
//
// 2D fp16 default-specialization instance that varies only the
// GemmSpecialization (padding) parameter.
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
using GemmPaddingTestInstance = DeviceInstanceForTests_V3<
    2,
    ck::tensor_layout::convolution::GNHWC,
    ck::tensor_layout::convolution::GKYXC,
    ck::tensor_layout::convolution::GNHWK,
    ck::half_t,
    ck::half_t,
    ck::half_t,
    float,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
    GemmSpec>;
// ============================================================================
// Test Enum Conversion Functions
// ============================================================================
// Each test checks one CK enum -> builder enum conversion helper, enumerator
// by enumerator.
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion)
{
    namespace rc  = ck_tile::reflect::conv;
    using CkVer   = ::ck::BlockGemmPipelineVersion;
    using Builder = ::ck_tile::builder::PipelineVersion;
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v1>(), Builder::V1);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v2>(), Builder::V2);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v3>(), Builder::V3);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v4>(), Builder::V4);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v5>(), Builder::V5);
}
TEST(InstanceToConvTraits, ConvertsPipelineVersion)
{
    namespace rc = ck_tile::reflect::conv;
    using CkVer  = ck::PipelineVersion;
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v1>(), PipelineVersion::V1);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v2>(), PipelineVersion::V2);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::v4>(), PipelineVersion::V4);
    EXPECT_EQ(rc::convert_pipeline_version<CkVer::weight_only>(), PipelineVersion::WEIGHT_ONLY);
}
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler)
{
    namespace rc  = ck_tile::reflect::conv;
    using CkSched = ck::BlockGemmPipelineScheduler;
    EXPECT_EQ(rc::convert_pipeline_scheduler<CkSched::Intrawave>(), PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(rc::convert_pipeline_scheduler<CkSched::Interwave>(), PipelineScheduler::INTERWAVE);
}
TEST(InstanceToConvTraits, ConvertsLoopScheduler)
{
    namespace rc  = ck_tile::reflect::conv;
    using CkSched = ck::LoopScheduler;
    EXPECT_EQ(rc::convert_pipeline_scheduler<CkSched::Default>(), PipelineScheduler::DEFAULT);
    EXPECT_EQ(rc::convert_pipeline_scheduler<CkSched::Interwave>(), PipelineScheduler::INTERWAVE);
}
// ============================================================================
// Test Elementwise Operations
// ============================================================================
TEST(InstanceToConvTraits, ExtractsPassThroughOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::PassThrough>();
EXPECT_EQ(op, PASS_THROUGH);
}
TEST(InstanceToConvTraits, ExtractsScaleOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Scale>();
EXPECT_EQ(op, SCALE);
}
TEST(InstanceToConvTraits, ExtractsReluOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Relu>();
EXPECT_EQ(op, RELU);
}
TEST(InstanceToConvTraits, ExtractsGeluOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Gelu>();
EXPECT_EQ(op, GELU);
}
TEST(InstanceToConvTraits, ExtractsSigmoidOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Sigmoid>();
EXPECT_EQ(op, SIGMOID);
}
TEST(InstanceToConvTraits, ExtractsTanhOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::TanH>();
EXPECT_EQ(op, TANH);
}
TEST(InstanceToConvTraits, ExtractsScaleAddOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ScaleAdd>();
EXPECT_EQ(op, SCALE_ADD);
}
TEST(InstanceToConvTraits, ExtractsSiluOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Silu>();
EXPECT_EQ(op, SILU);
}
TEST(InstanceToConvTraits, ExtractsSwishOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Swish>();
EXPECT_EQ(op, SWISH);
}
TEST(InstanceToConvTraits, ExtractsEluOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Elu>();
EXPECT_EQ(op, ELU);
}
TEST(InstanceToConvTraits, ExtractsLeakyReluOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::LeakyRelu>();
EXPECT_EQ(op, LEAKY_RELU);
}
TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::UnaryConvert>();
EXPECT_EQ(op, UNARY_CONVERT);
}
TEST(InstanceToConvTraits, ExtractsConvScaleOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScale>();
EXPECT_EQ(op, CONV_SCALE);
}
TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation)
{
using enum ElementwiseOperation;
constexpr auto op =
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScaleAdd>();
EXPECT_EQ(op, CONV_SCALE_ADD);
}
// ============================================================================
// Test Convolution Direction Detection
// ============================================================================
// The forward Device instance must be reported as a forward convolution.
TEST(InstanceToConvTraits, DetectsForwardDirection)
{
    const auto traits =
        ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstanceForTests_V3<>>();
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
}
// ============================================================================
// Test Convolution Specialization Detection
// ============================================================================
TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
{
using DeviceInstance = SpecializationTestInstance<
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
{
using DeviceInstance = SpecializationTestInstance<
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization)
{
using DeviceInstance = SpecializationTestInstance<
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.conv_specialization,
ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0);
}
TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization)
{
using DeviceInstance = SpecializationTestInstance<
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3);
}
TEST(InstanceToConvTraits, ExtractsOddCSpecialization)
{
using DeviceInstance = SpecializationTestInstance<
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C);
}
// ============================================================================
// Test 1D Convolution Layout Detection
// ============================================================================
// Each test checks both the extracted spatial dimension and the extracted
// (input, weight, output) layout triple.
TEST(InstanceToConvTraits, ExtractsGnwcLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<1, cl::GNWC, cl::GKXC, cl::GNWK>>();
    EXPECT_EQ(traits.spatial_dim, 1);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK));
}
TEST(InstanceToConvTraits, ExtractsNwgcLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<1, cl::NWGC, cl::GKXC, cl::NWGK>>();
    EXPECT_EQ(traits.spatial_dim, 1);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK));
}
TEST(InstanceToConvTraits, ExtractsNgcwLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<1, cl::NGCW, cl::GKXC, cl::NGKW>>();
    EXPECT_EQ(traits.spatial_dim, 1);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW));
}
// ============================================================================
// Test 2D Convolution Layout Detection
// ============================================================================
// Each test verifies both the extracted spatial dimension and the extracted
// (input, weight, output) layout triple, mirroring the 1D and 3D layout
// tests (which already assert spatial_dim).
TEST(InstanceToConvTraits, ExtractsGnhwcLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::GNHWC,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::GNHWK>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
}
TEST(InstanceToConvTraits, ExtractsNhwgcLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NHWGC,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::NHWGK>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK));
}
TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NGCHW,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::NGKHW>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW));
}
TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NGCHW,
                                              ck::tensor_layout::convolution::GKCYX,
                                              ck::tensor_layout::convolution::NGKHW>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW));
}
// ============================================================================
// Test 3D Convolution Layout Detection
// ============================================================================
// Each test checks both the extracted spatial dimension and the extracted
// (input, weight, output) layout triple.
TEST(InstanceToConvTraits, ExtractsGndhwcLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<3, cl::GNDHWC, cl::GKZYXC, cl::GNDHWK>>();
    EXPECT_EQ(traits.spatial_dim, 3);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK));
}
TEST(InstanceToConvTraits, ExtractsNdhwgcLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<3, cl::NDHWGC, cl::GKZYXC, cl::NDHWGK>>();
    EXPECT_EQ(traits.spatial_dim, 3);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK));
}
TEST(InstanceToConvTraits, ExtractsNgcdhwLayout)
{
    namespace cl = ck::tensor_layout::convolution;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        LayoutTestInstance<3, cl::NGCDHW, cl::GKZYXC, cl::NGKDHW>>();
    EXPECT_EQ(traits.spatial_dim, 3);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW));
}
// ============================================================================
// Test Data Type Detection
// ============================================================================
// Each test checks that the instance's A/B/E element type maps to the
// matching builder DataType enumerator.
TEST(InstanceToConvTraits, ExtractsFp16DataType)
{
    const auto traits =
        ck_tile::reflect::conv::instance_to_conv_traits<DataTypeTestInstance<ck::half_t>>();
    EXPECT_EQ(traits.data_type, DataType::FP16);
}
TEST(InstanceToConvTraits, ExtractsBf16DataType)
{
    const auto traits =
        ck_tile::reflect::conv::instance_to_conv_traits<DataTypeTestInstance<ck::bhalf_t>>();
    EXPECT_EQ(traits.data_type, DataType::BF16);
}
TEST(InstanceToConvTraits, ExtractsFp32DataType)
{
    const auto traits =
        ck_tile::reflect::conv::instance_to_conv_traits<DataTypeTestInstance<float, float>>();
    EXPECT_EQ(traits.data_type, DataType::FP32);
}
TEST(InstanceToConvTraits, ExtractsI8DataType)
{
    const auto traits =
        ck_tile::reflect::conv::instance_to_conv_traits<DataTypeTestInstance<int8_t, int32_t>>();
    EXPECT_EQ(traits.data_type, DataType::I8);
}
// ============================================================================
// Test Pipeline Version Detection
// ============================================================================
// Non-default pipeline versions and schedulers must survive extraction.
TEST(InstanceToConvTraits, ExtractsPipelineV2)
{
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v2>>();
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V2);
}
TEST(InstanceToConvTraits, ExtractsPipelineV3)
{
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v3>>();
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V3);
}
TEST(InstanceToConvTraits, ExtractsInterwaveScheduler)
{
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<
        PipelineSchedulerTestInstance<ck::BlockGemmPipelineScheduler::Interwave>>();
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTERWAVE);
}
// ============================================================================
// Test GEMM Padding Detection
// ============================================================================
TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::Default>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
}
TEST(InstanceToConvTraits, ExtractsMnkGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNK_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::M_PADDING);
}
TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::N_PADDING);
}
TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::K_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MN_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MK_PADDING);
}
TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NKPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::NK_PADDING);
}
TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::OPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::O_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::NO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::KO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MKO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NKOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::NKO_PADDING);
}
TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding)
{
using DeviceInstance =
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNKOPadding>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNKO_PADDING);
}
} // anonymous namespace

View File

@@ -0,0 +1,262 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// ============================================================================
// Unit Tests for Complete Device Instance Transformations
// ============================================================================
//
// PURPOSE:
// --------
// These tests verify the complete instance_to_conv_traits transformation
// for entire Device class templates. Each test validates that all traits
// are correctly extracted from a specific Device class instantiation.
//
// TEST COVERAGE:
// --------------
// Complete transformation verification for each XDL Device class template:
// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// 3. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
//
// Each test verifies:
// - Spatial dimension extraction
// - Convolution direction
// - Data type detection
// - GEMM padding configuration
// - Tile dimensions (M, N, K per block)
// - Pipeline scheduler and version
// ============================================================================
#include <gtest/gtest.h>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::ck_tile::builder::ConvDirection;
using ::ck_tile::builder::DataType;
using ::ck_tile::builder::GemmPadding;
using ::ck_tile::builder::PipelineScheduler;
using ::ck_tile::builder::PipelineVersion;
// ============================================================================
// Comprehensive Transformation Tests - Per Device Class Template
// ============================================================================
// These tests verify the complete InstanceTraits → ConvTraits transformation
// for each forward convolution Device class template.
// ============================================================================
// Full-instance transformation check for the V3 CShuffle forward kernel:
// every extracted ConvTraits field must agree with the InstanceTraits
// reflection of the same Device template instantiation.
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3)
{
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
            2,                                               // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,           // ALayout
            ck::tensor_layout::convolution::GKYXC,           // BLayout
            ck::Tuple<>,                                     // DsLayout
            ck::tensor_layout::convolution::GNHWK,           // ELayout
            ck::half_t,                                      // ADataType
            ck::half_t,                                      // BDataType
            float,                                           // AccDataType
            ck::half_t,                                      // CShuffleDataType
            ck::Tuple<>,                                     // DsDataType
            ck::half_t,                                      // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
            ck::tensor_operation::device::GemmSpecialization::Default,
            256,                                       // BlockSize
            128,                                       // MPerBlock
            128,                                       // NPerBlock
            16,                                        // KPerBlock
            8,                                         // AK1
            8,                                         // BK1
            32,                                        // MPerXDL
            32,                                        // NPerXDL
            4,                                         // MXdlPerWave
            4,                                         // NXdlPerWave
            ck::Sequence<4, 64, 1>,                    // ABlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,                     // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,                     // ABlockTransferSrcAccessOrder
            2,                                         // ABlockTransferSrcVectorDim
            8,                                         // ABlockTransferSrcScalarPerVector
            8,                                         // ABlockTransferDstScalarPerVector_AK1
            1,                                         // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,                    // BBlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,                     // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,                     // BBlockTransferSrcAccessOrder
            2,                                         // BBlockTransferSrcVectorDim
            8,                                         // BBlockTransferSrcScalarPerVector
            8,                                         // BBlockTransferDstScalarPerVector_BK1
            1,                                         // BBlockLdsExtraN
            1,                                         // CShuffleMXdlPerWavePerShuffle
            1,                                         // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>,                 // CDEBlockTransferClusterLengths
            8,                                         // CDEBlockTransferScalarPerVector_NPerBlock
            ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
            ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
            ck::half_t,                                // AComputeDataType
            ck::half_t,                                // BComputeDataType
            false>;                                    // DirectLoad
    using InstTraits  = ck_tile::reflect::InstanceTraits<DeviceInstance>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // Verify signature information
    EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(traits.data_type, DataType::FP16);
    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
    // Verify tile dimensions
    EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
    EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
    EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
    // Verify pipeline configuration
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
// Full-instance transformation check for the (non-V3) CShuffle forward
// kernel.  Unlike the V3 variant, this Device class is parameterized on
// ck::LoopScheduler rather than ck::BlockGemmPipelineScheduler, so the
// extracted scheduler is PipelineScheduler::DEFAULT.
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle)
{
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
            2,                                               // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,           // ALayout
            ck::tensor_layout::convolution::GKYXC,           // BLayout
            ck::Tuple<>,                                     // DsLayout
            ck::tensor_layout::convolution::GNHWK,           // ELayout
            ck::half_t,                                      // ADataType
            ck::half_t,                                      // BDataType
            float,                                           // AccDataType
            ck::half_t,                                      // CShuffleDataType
            ck::Tuple<>,                                     // DsDataType
            ck::half_t,                                      // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
            ck::tensor_operation::device::GemmSpecialization::Default,
            1,                          // NumGemmKPrefetchStage
            256,                        // BlockSize
            128,                        // MPerBlock
            128,                        // NPerBlock
            16,                         // KPerBlock
            8,                          // AK1
            8,                          // BK1
            32,                         // MPerXDL
            32,                         // NPerXDL
            4,                          // MXdlPerWave
            4,                          // NXdlPerWave
            ck::Sequence<4, 64, 1>,     // ABlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,      // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,      // ABlockTransferSrcAccessOrder
            2,                          // ABlockTransferSrcVectorDim
            8,                          // ABlockTransferSrcScalarPerVector
            8,                          // ABlockTransferDstScalarPerVector_AK1
            1,                          // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,     // BBlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,      // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,      // BBlockTransferSrcAccessOrder
            2,                          // BBlockTransferSrcVectorDim
            8,                          // BBlockTransferSrcScalarPerVector
            8,                          // BBlockTransferDstScalarPerVector_BK1
            1,                          // BBlockLdsExtraN
            1,                          // CShuffleMXdlPerWavePerShuffle
            1,                          // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>,  // CDEBlockTransferClusterLengths
            8,                          // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                 // AComputeDataType
            ck::half_t,                 // BComputeDataType
            ck::LoopScheduler::Default, // LoopSched
            1>;                         // NumGroupsToMerge
    using InstTraits  = ck_tile::reflect::InstanceTraits<DeviceInstance>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // Verify signature information
    EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(traits.data_type, DataType::FP16);
    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
    // Verify tile dimensions
    EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
    EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
    EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
    // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler)
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
// Verifies that instance_to_conv_traits() correctly extracts signature, tile-dimension,
// and pipeline information from a DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// instance (2D grouped forward convolution, FP16, GNHWC/GKYXC/GNHWK layouts).
// Extracted values are cross-checked against InstanceTraits reflection where available,
// and against the hard-coded template arguments otherwise.
TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor)
{
    // NOTE(review): unlike some sibling Fwd instantiations, this Large_Tensor variant's
    // parameter list appears to end at LoopSched (no trailing NumGroupsToMerge) — confirm
    // against the device header if the template arguments are ever modified.
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
            2,                                          // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,      // ALayout
            ck::tensor_layout::convolution::GKYXC,      // BLayout
            ck::Tuple<>,                                // DsLayout (no extra D tensors)
            ck::tensor_layout::convolution::GNHWK,      // ELayout
            ck::half_t,                                 // ADataType
            ck::half_t,                                 // BDataType
            float,                                      // AccDataType
            ck::half_t,                                 // CShuffleDataType
            ck::Tuple<>,                                // DsDataType (matches empty DsLayout)
            ck::half_t,                                 // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
            ck::tensor_operation::device::GemmSpecialization::Default,
            1,                        // NumGemmKPrefetchStage
            256,                      // BlockSize
            128,                      // MPerBlock
            128,                      // NPerBlock
            16,                       // KPerBlock
            8,                        // AK1
            8,                        // BK1
            32,                       // MPerXDL
            32,                       // NPerXDL
            4,                        // MXdlPerWave
            4,                        // NXdlPerWave
            ck::Sequence<4, 64, 1>,   // ABlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,    // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,    // ABlockTransferSrcAccessOrder
            2,                        // ABlockTransferSrcVectorDim
            8,                        // ABlockTransferSrcScalarPerVector
            8,                        // ABlockTransferDstScalarPerVector_AK1
            1,                        // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,   // BBlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,    // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,    // BBlockTransferSrcAccessOrder
            2,                        // BBlockTransferSrcVectorDim
            8,                        // BBlockTransferSrcScalarPerVector
            8,                        // BBlockTransferDstScalarPerVector_BK1
            1,                        // BBlockLdsExtraN
            1,                        // CShuffleMXdlPerWavePerShuffle
            1,                        // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
            8,                        // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,               // AComputeDataType
            ck::half_t,               // BComputeDataType
            ck::LoopScheduler::Default>; // LoopSched
    // Compile-time reflection of the instance, used as the reference for numeric fields.
    using InstTraits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
    // System under test: runtime ConvTraits extracted from the instance type.
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // Verify signature information
    EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(traits.data_type, DataType::FP16);
    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); // from GemmSpecialization::Default
    // Verify tile dimensions
    EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
    EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
    EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
    // Verify pipeline configuration (LoopScheduler::Default is expected to map to
    // PipelineScheduler::DEFAULT / PipelineVersion::V1)
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
} // anonymous namespace