mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Convert convolution traits to a struct with factory functions (#3547)
* Factor helpers out of conv_traits.hpp * Create a non-templated conv_traits struct * Migrate to new instance-specific instance_to_conv_traits functions * Clean up reflection concepts * Clean up ConvTraits helpers * Update testing for convolution traits This is a lot of cleanup on tests to have verbose coverage of feature extraction, explicit tests for each supported device kernel, and simple, readable test code. * Address reviewer comments and resolve merge conflict
This commit is contained in:
@@ -7,43 +7,52 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_description.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <conv::HasConvTraits Instance>
|
||||
/// @brief Concept to check if an Instance type has conv traits
|
||||
template <typename Instance>
|
||||
concept HasConvTraits = requires {
|
||||
{ conv::instance_to_conv_traits<Instance>() };
|
||||
};
|
||||
|
||||
/// Factory function to create ConvDescription from a convolution instance type
|
||||
/// Instance The convolution instance type
|
||||
/// A ConvDescription object populated with the instance's configuration details
|
||||
///
|
||||
/// TODO: Fix ConvDescription to just use the ConvTraits directly.
|
||||
template <typename Instance>
|
||||
requires HasConvTraits<Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
const auto traits = conv::instance_to_conv_traits<Instance>();
|
||||
|
||||
return conv::ConvDescription(
|
||||
conv::ConvSignatureInfo{
|
||||
.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.input_layout = Traits::layout[0],
|
||||
.weight_layout = Traits::layout[1],
|
||||
.output_layout = Traits::layout[2],
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op,
|
||||
.spatial_dim = traits.spatial_dim,
|
||||
.direction = traits.direction,
|
||||
.input_layout = traits.layout[0],
|
||||
.weight_layout = traits.layout[1],
|
||||
.output_layout = traits.layout[2],
|
||||
.data_type = traits.data_type,
|
||||
.input_element_op = traits.input_element_op,
|
||||
.weight_element_op = traits.weight_element_op,
|
||||
.output_element_op = traits.output_element_op,
|
||||
},
|
||||
conv::GemmAlgorithmInfo{
|
||||
.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding,
|
||||
.thread_block_size = traits.thread_block_size,
|
||||
.tile_dims = traits.tile_dims,
|
||||
.warp_gemm = traits.warp_gemm,
|
||||
.a_tile_transfer = traits.a_tile_transfer,
|
||||
.b_tile_transfer = traits.b_tile_transfer,
|
||||
.c_tile_transfer = traits.c_tile_transfer,
|
||||
.pipeline_version = traits.pipeline_version,
|
||||
.pipeline_scheduler = traits.pipeline_scheduler,
|
||||
.conv_specialization = traits.conv_specialization,
|
||||
.padding = traits.gemm_padding,
|
||||
},
|
||||
[]() { return reflect::instance_string<Instance>(); });
|
||||
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
|
||||
@@ -1,664 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Runtime-accessible convolution kernel configuration data structure
|
||||
//
|
||||
// This file defines ConvTraits, a pure data structure that captures the complete
|
||||
// configuration of a convolution kernel in a domain-specific abstraction, without
|
||||
// requiring knowledge of the underlying kernel instance implementation details.
|
||||
//
|
||||
// ## Purpose and Design
|
||||
//
|
||||
// ConvTraits provides type erasure for convolution kernel configurations, allowing
|
||||
// for reflection of convolution kernel objects. The struct represents kernel
|
||||
// traits in terms of convolution-specific concepts for AMD GPUs rather than raw
|
||||
// template parameters.
|
||||
//
|
||||
// ## Architecture and Usage
|
||||
//
|
||||
// ConvTraits sits at the center of the reflection system:
|
||||
//
|
||||
// 1. **Population**: Values are created by `instance_to_conv_traits()` template
|
||||
// specializations that extract configuration from compile-time InstanceTraits
|
||||
//
|
||||
// 2. **Consumption**: Used by ConvDescription to provide human-readable descriptions
|
||||
// of kernel configurations for debugging, logging, and documentation
|
||||
//
|
||||
// ## Structure Organization
|
||||
//
|
||||
// The struct separates kernel configuration into two logical categories:
|
||||
//
|
||||
// - **Signature Information**: Defines what the kernel computes (direction, layouts,
|
||||
// data types, elementwise operations, specializations)
|
||||
//
|
||||
// - **Algorithm Information**: Defines how the kernel computes (thread block size,
|
||||
// tile dimensions, memory access patterns, pipeline configuration)
|
||||
//
|
||||
// ## Evolution and Extensibility
|
||||
//
|
||||
// ConvTraits is designed to evolve through composition (not inheritance):
|
||||
//
|
||||
// - Currently supports XDL forward convolution kernels
|
||||
// - Will extend to the other forward convolutions
|
||||
// - Will be extended to cover backward data and backward weight convolutions
|
||||
// - Will incorporate fusion operations and additional specializations
|
||||
// - Uses std::optional and std::variant for optional/variant fields
|
||||
// - Eventually will generalize to KernelTraits for GEMM, flash attention, etc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_types.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Forward convolution layout concept - checks for A/B/E layout types
|
||||
template <typename T>
|
||||
concept HasFwdConvLayouts = requires {
|
||||
typename T::ALayout;
|
||||
typename T::BLayout;
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
{
|
||||
T::kGemmSpecialization
|
||||
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
|
||||
};
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
concept HasElementwiseOps = requires {
|
||||
typename T::AElementwiseOperation;
|
||||
typename T::BElementwiseOperation;
|
||||
typename T::CDEElementwiseOperation;
|
||||
};
|
||||
|
||||
// Tile parameters concept - checks for tile dimension and transfer members
|
||||
template <typename T>
|
||||
concept HasTileParams = requires {
|
||||
{ T::kKPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kMPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kNPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kAK1 } -> std::convertible_to<int>;
|
||||
{ T::kBK1 } -> std::convertible_to<int>;
|
||||
T::kCThreadClusterLengths;
|
||||
};
|
||||
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
// Runtime data structure representing a convolution kernel's complete configuration
|
||||
//
|
||||
// This pure data struct (no template parameters, no static members) provides
|
||||
// type erasure for convolution kernel configurations. It can hold the configuration
|
||||
// from any convolution kernel instance, enabling runtime storage, comparison, and
|
||||
// manipulation of kernel properties.
|
||||
//
|
||||
// The struct is populated by `instance_to_conv_traits()` template specializations
|
||||
// that extract compile-time configuration from InstanceTraits and convert it to
|
||||
// this standardized runtime representation.
|
||||
//
|
||||
// Members are organized into two categories:
|
||||
// - **Signature Information**: Defines the computational interface (what to compute)
|
||||
// - **Algorithm Information**: Defines the implementation strategy (how to compute)
|
||||
//
|
||||
// Note: This struct will evolve to support additional convolution variants and
|
||||
// eventually generalize to other kernel types through composition.
|
||||
//
|
||||
// There is a lot we still need to do:
|
||||
//
|
||||
// TODO: Generalize type support for all tensors and accumulator.
|
||||
// TODO: Describe all tensors.
|
||||
// TODO: Include the full generalization of the signature from the input schema.
|
||||
// TODO: Include the full generalization of the algorithm from the input schema.
|
||||
struct ConvTraits
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
|
||||
/// across multiple wavefronts.
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
|
||||
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
|
||||
/// performance in certain scenarios.
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper metafunctions to derive signature information from Instance types
|
||||
|
||||
/// @brief Helper function to report unsupported convolution direction with a clear error message.
|
||||
template <typename Instance>
|
||||
[[noreturn]] consteval void report_unsupported_conv_direction_error()
|
||||
{
|
||||
throw "Unsupported convolution direction detected!\n"
|
||||
"The kernel instance does not have a recognized convolution specialization.\n"
|
||||
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
|
||||
"kConvBwdWeightSpecialization.\n"
|
||||
"Please verify that your kernel instance is properly configured.";
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
return builder::ConvDirection::FORWARD;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_direction_error<Instance>();
|
||||
return builder::ConvDirection::FORWARD; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief True when the A/B/E layout types exactly match the expected triple.
/// Used by conv_layout() to pattern-match CK layout-type combinations.
template <typename A,
          typename B,
          typename E,
          typename ExpectedA,
          typename ExpectedB,
          typename ExpectedE>
inline constexpr bool layouts_are =
    std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Helper function to report unsupported layout combinations with a clear error message.
|
||||
/// @details This consteval function is designed to fail at compile time with a descriptive
|
||||
/// error message when an unsupported layout combination is encountered.
|
||||
template <typename A, typename B, typename E, int SpatialDim>
|
||||
[[noreturn]] consteval void report_unsupported_layout_error()
|
||||
{
|
||||
// This will produce a compile-time error with the exception message
|
||||
throw "Unsupported convolution layout combination detected!\n"
|
||||
"The combination of ALayout, BLayout, and ELayout template parameters\n"
|
||||
"is not recognized for the given spatial dimension.\n"
|
||||
"Please verify that your convolution instance uses a supported layout configuration.\n"
|
||||
"Check the conv_layout() function for the list of supported layout combinations.";
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
/// index 0 -> Input layout
|
||||
/// index 1 -> Weight layout
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
{
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
return layouts(NGCW, GKXC, NGKW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
|
||||
return layouts(NGCW, GKCX, NGKW);
|
||||
break;
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKCYX, NGKHW);
|
||||
break;
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKZYXC, NGKDHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
|
||||
// If we reach here, the layout combination is not supported
|
||||
// Call consteval function to trigger a compile-time error with a clear message
|
||||
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
|
||||
// This return is unreachable but needed to satisfy the compiler
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported data type with a clear error message.
|
||||
template <typename ADataType>
|
||||
[[noreturn]] consteval void report_unsupported_data_type_error()
|
||||
{
|
||||
throw "Unsupported data type detected!\n"
|
||||
"The ADataType is not recognized.\n"
|
||||
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
|
||||
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
|
||||
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
|
||||
"(BF8), "
|
||||
"int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
|
||||
"Please verify that your kernel instance uses a supported data type.";
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
|
||||
template <typename ElementwiseOp>
|
||||
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
|
||||
{
|
||||
throw "Unsupported elementwise operation detected!\n"
|
||||
"The elementwise operation type is not recognized.\n"
|
||||
"Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
|
||||
"BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
|
||||
"ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
|
||||
"UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
|
||||
"Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
|
||||
"UnaryConvert.\n"
|
||||
"Please verify that your kernel instance uses a supported elementwise operation.";
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
/// @tparam ElementwiseOp Elementwise operation functor type.
|
||||
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
|
||||
return ADD_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
|
||||
return ADD_RELU_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
|
||||
return BILINEAR;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
|
||||
return CONV_INVSCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
|
||||
return CONV_SCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
|
||||
return CONV_SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
|
||||
return CONV_SCALE_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
|
||||
return SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
|
||||
return DYNAMIC_UNARY_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
|
||||
return UNARY_COMBINED_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
|
||||
return ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
|
||||
return ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
|
||||
return ADD_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
|
||||
return ADD_ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
|
||||
return ADD_MUL_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
|
||||
return ADD_MUL2_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
|
||||
return UNARY_CONVERT;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
|
||||
return LOGISTIC;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
|
||||
return CLIPPED_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Swish"))
|
||||
return SWISH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Elu"))
|
||||
return ELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Power"))
|
||||
return POWER;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
|
||||
return LEAKY_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
|
||||
return UNARY_ABS;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Relu"))
|
||||
return RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
|
||||
return SOFT_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
|
||||
return SIGMOID;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "TanH"))
|
||||
return TANH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
|
||||
return GELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Silu"))
|
||||
return SILU;
|
||||
else
|
||||
{
|
||||
report_unsupported_elementwise_op_error<ElementwiseOp>();
|
||||
return PASS_THROUGH; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
/// @tparam Instance - A Device Kernel object type.
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
requires HasGemmSpec<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
switch(gemm_spec)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Primary template for extracting convolution traits.
|
||||
/// @details This struct is the main entry point for reflecting on a convolution
|
||||
/// kernel's properties. It is specialized to handle different kinds of input types.
|
||||
template <typename T>
|
||||
struct ConvTraits;
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
|
||||
/// @details This is the primary specialization used to extract a comprehensive
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <HasInstanceTraits Instance>
|
||||
requires IsXdlFwdConv<InstanceTraits<Instance>>
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
// --- Signature Information ---
|
||||
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
|
||||
static constexpr int spatial_dim = InstTraits::kSpatialDim;
|
||||
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
|
||||
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
|
||||
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
|
||||
static constexpr auto layout = conv_layout<Instance>();
|
||||
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
|
||||
static constexpr builder::DataType data_type = conv_data_type<Instance>();
|
||||
int spatial_dim;
|
||||
builder::ConvDirection direction;
|
||||
std::array<builder::TensorLayout, 3> layout; // [input, weight, output]
|
||||
builder::DataType data_type;
|
||||
|
||||
static constexpr builder::ElementwiseOperation input_element_op =
|
||||
elementwise_op<typename InstTraits::AElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation weight_element_op =
|
||||
elementwise_op<typename InstTraits::BElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation output_element_op =
|
||||
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
|
||||
builder::ElementwiseOperation input_element_op;
|
||||
builder::ElementwiseOperation weight_element_op;
|
||||
builder::ElementwiseOperation output_element_op;
|
||||
|
||||
/// @brief The GEMM specialization used by the kernel - padding
|
||||
static constexpr auto gemm_padding = gemm_spec<Instance>();
|
||||
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
|
||||
static constexpr auto conv_specialization = conv_spec<Instance>();
|
||||
builder::GemmPadding gemm_padding;
|
||||
builder::ConvSpecialization conv_specialization;
|
||||
|
||||
// --- Algorithm Information ---
|
||||
/// @brief The total number of threads in a thread block (workgroup).
|
||||
static constexpr int thread_block_size = InstTraits::kBlockSize;
|
||||
/// @brief The dimensions of the data tile processed by the thread block.
|
||||
static constexpr DataTileInfo tile_dims = {
|
||||
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
|
||||
int thread_block_size;
|
||||
DataTileInfo tile_dims;
|
||||
|
||||
/// @brief Configuration for the A-matrix (input) tile transfer.
|
||||
static constexpr InputTileTransferInfo a_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
|
||||
InputTileTransferInfo a_tile_transfer;
|
||||
InputTileTransferInfo b_tile_transfer;
|
||||
|
||||
/// @brief Configuration for the B-matrix (weights) tile transfer.
|
||||
static constexpr InputTileTransferInfo b_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
|
||||
WarpGemmParams warp_gemm;
|
||||
|
||||
/// @brief Parameters for the warp-level GEMM computation.
|
||||
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave};
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
|
||||
/// @brief Configuration for the C-matrix (output) tile transfer.
|
||||
static constexpr OutputTileTransferInfo c_tile_transfer = {
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
|
||||
/// @brief Helper to safely get the pipeline version.
|
||||
/// @details This is only available for some convolutions (e.g., forward).
|
||||
/// If not present in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_version()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineVersion; })
|
||||
{
|
||||
return convert_pipeline_version<T::kPipelineVersion>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineVersion::V1;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The block GEMM pipeline version used by the kernel.
|
||||
static constexpr auto pipeline_version = get_pipeline_version();
|
||||
|
||||
/// @brief Helper to safely get the pipeline scheduler.
|
||||
/// @details This is only available for some convolutions. If not present
|
||||
/// in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_scheduler()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kPipelineScheduler>();
|
||||
}
|
||||
else if constexpr(requires { T::kLoopScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kLoopScheduler>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineScheduler::DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The pipeline scheduler used by the kernel.
|
||||
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,739 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_types.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
/// @file conv_traits_helpers.hpp
|
||||
/// @brief Helper utilities for extracting convolution traits from kernel instances
|
||||
///
|
||||
/// This file provides compile-time reflection utilities to extract configuration
|
||||
/// information from CK convolution kernel instances and convert them to the builder
|
||||
/// framework's standardized representation.
|
||||
///
|
||||
/// ## Organization
|
||||
///
|
||||
/// The file is organized into the following sections:
|
||||
///
|
||||
/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums
|
||||
/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion)
|
||||
/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler)
|
||||
///
|
||||
/// 2. **Signature Derivation**: Functions to extract signature information from instances
|
||||
/// - Convolution direction (conv_direction)
|
||||
/// - Convolution specialization (conv_spec)
|
||||
/// - Tensor layouts (conv_layout)
|
||||
/// - Data types (conv_data_type)
|
||||
/// - Elementwise operations (elementwise_op)
|
||||
/// - GEMM padding (gemm_spec)
|
||||
///
|
||||
/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters
|
||||
/// - Pipeline version extraction (get_pipeline_version)
|
||||
/// - Pipeline scheduler extraction (get_pipeline_scheduler)
|
||||
///
|
||||
/// ## Error Handling Strategy
|
||||
///
|
||||
/// This file uses a specific error handling pattern for compile-time errors:
|
||||
/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't
|
||||
/// silently ignore errors. The thrown string becomes part of the compiler error message,
|
||||
/// providing clear context to developers.
|
||||
/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE,
|
||||
/// which would hide errors instead of reporting them clearly.
|
||||
///
|
||||
/// @example
|
||||
/// ```cpp
|
||||
/// using Instance =
|
||||
/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>;
|
||||
///
|
||||
/// // Extract convolution direction
|
||||
/// constexpr auto dir = conv_direction<Instance>();
|
||||
///
|
||||
/// // Extract data type
|
||||
/// constexpr auto dtype = conv_data_type<Instance>();
|
||||
///
|
||||
/// // Extract layout configuration
|
||||
/// constexpr auto layouts = conv_layout<Instance>();
|
||||
/// ```
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 1: ENUM CONVERSIONS
|
||||
// ============================================================================
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value.
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - v1 -> V1
|
||||
/// - v2 -> V2
|
||||
/// - v3 -> V3
|
||||
/// - v4 -> V4
|
||||
/// - v5 -> V5
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr builder::PipelineVersion convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value.
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - v1 -> V1
|
||||
/// - v2 -> V2
|
||||
/// - v4 -> V4
|
||||
/// - weight_only -> WEIGHT_ONLY
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr builder::PipelineVersion convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value.
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront
|
||||
/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value.
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - Default -> DEFAULT: Standard scheduling strategy
|
||||
/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 2: SIGNATURE DERIVATION FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Convolution Direction
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported convolution direction with a clear error message.
|
||||
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
|
||||
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
|
||||
template <typename Instance>
|
||||
[[noreturn]] consteval void report_unsupported_conv_direction_error()
|
||||
{
|
||||
throw "Unsupported convolution direction detected!\n"
|
||||
"The kernel instance does not have a recognized convolution specialization.\n"
|
||||
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
|
||||
"kConvBwdWeightSpecialization.\n"
|
||||
"Please verify that your kernel instance is properly configured.";
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
/// @details This function inspects the Instance's InstanceTraits to determine which
|
||||
/// convolution specialization field is present, and returns the corresponding direction.
|
||||
///
|
||||
/// The function checks for the presence of:
|
||||
/// - kConvForwardSpecialization -> FORWARD
|
||||
/// - kConvBwdDataSpecialization -> BACKWARD_DATA
|
||||
/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT
|
||||
///
|
||||
/// @note Compilation will fail with a clear error message if the instance does not
|
||||
/// have a recognized convolution specialization field.
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
return builder::ConvDirection::FORWARD;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_direction_error<Instance>();
|
||||
return builder::ConvDirection::FORWARD; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Convolution Specialization
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported convolution specialization with a clear error
/// message.
/// @tparam Instance The offending device kernel instance type; carried in the signature so the
///         failing instantiation shows up in the compiler's error trace.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename Instance>
[[noreturn]] consteval void report_unsupported_conv_spec_error()
{
    // Adjacent string literals concatenate into the single diagnostic shown by the compiler.
    throw "Unsupported convolution specialization detected!\n"
          "The kernel instance does not have a recognized convolution specialization field.\n"
          "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
          "kConvBwdWeightSpecialization.\n"
          "Please verify that your kernel instance is properly configured.";
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::ConvSpecialization enum value.
|
||||
/// @details This function extracts the specialization enum from the Instance's InstanceTraits
|
||||
/// and converts it to the corresponding builder framework enum.
|
||||
///
|
||||
/// For forward convolutions, supported specializations include:
|
||||
/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC
|
||||
///
|
||||
/// For backward data convolutions:
|
||||
/// - Default, Filter1x1Stride1Pad0
|
||||
///
|
||||
/// For backward weight convolutions:
|
||||
/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvSpecialization conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_spec_error<Instance>();
|
||||
return builder::ConvSpecialization::DEFAULT; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Tensor Layouts
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @tparam A The instance's ALayout (input) type, shown in the compiler's error trace.
/// @tparam B The instance's BLayout (weight) type, shown in the compiler's error trace.
/// @tparam E The instance's ELayout (output) type, shown in the compiler's error trace.
/// @tparam SpatialDim The spatial dimensionality of the convolution (1, 2, or 3).
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
    // Adjacent string literals concatenate into the single diagnostic shown by the compiler.
    throw "Unsupported convolution layout combination detected!\n"
          "The combination of ALayout, BLayout, and ELayout template parameters\n"
          "is not recognized for the given spatial dimension.\n"
          "Please verify that your convolution instance uses a supported layout configuration.\n"
          "Check the conv_layout() function for the list of supported layout combinations.";
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
///         - [0] Input tensor layout
///         - [1] Weight tensor layout
///         - [2] Output tensor layout
/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
/// along with the spatial dimension to determine the appropriate layout configuration.
///
/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
///
/// @note Compilation will fail with a clear error message if the layout combination
/// is not supported for the given spatial dimension.
///
/// TODO: If we don't check for supported layouts, this function can be simplified.
template <typename Instance>
constexpr std::array<builder::TensorLayout, 3> conv_layout()
{
    using InstTraits = InstanceTraits<Instance>;
    using A          = typename InstTraits::ALayout;
    using B          = typename InstTraits::BLayout;
    using E          = typename InstTraits::ELayout;
    namespace ctl    = ck::tensor_layout::convolution;
    using enum builder::TensorLayout;

    // Helper to check if the instance's (A, B, E) layouts exactly match an expected triple.
    constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
        return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
    };

    // Helper to construct the (input, weight, output) layout array.
    constexpr auto make_layouts = [](auto in, auto weight, auto out) {
        return std::array<builder::TensorLayout, 3>{in, weight, out};
    };

    constexpr int spatial_dim = InstTraits::kSpatialDim;

    if constexpr(spatial_dim == 1)
    {
        // Split-descriptor variants (G_NW_C-style) are reported as the packed enum value;
        // builder::TensorLayout apparently does not distinguish them -- confirm intended.
        if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
            return make_layouts(GNWC, GKXC, GNWK);
        else if constexpr(layouts_match.template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
            return make_layouts(GNWC, GKXC, GNWK);
        else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
            return make_layouts(NWGC, GKXC, NWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
            return make_layouts(NGCW, GKXC, NGKW);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
            return make_layouts(NGCW, GKCX, NGKW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNWC, GKXC, GNWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 2)
    {
        if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        else if constexpr(layouts_match.template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        // NOTE(review): a KYXGC weight layout is reported as GKYXC here. Presumably
        // builder::TensorLayout has no KYXGC value and this is deliberate normalization,
        // but it loses the group-position distinction -- confirm intended.
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
            return make_layouts(NGCHW, GKYXC, NGKHW);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
            return make_layouts(NGCHW, GKCYX, NGKHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 3)
    {
        if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        else if constexpr(layouts_match.template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        else if constexpr(layouts_match.template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
            return make_layouts(NDHWGC, GKZYXC, NDHWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKZYXC, NGKDHW);
        else if constexpr(layouts_match.template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKCZYX, NGKDHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
        }
    }
    else
    {
        // Spatial dimensions other than 1, 2, or 3 are not supported.
        report_unsupported_layout_error<A, B, E, spatial_dim>();
        return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Data Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported data type with a clear error message.
/// @tparam ADataType The unrecognized input data type; carried in the signature so it appears
///         in the compiler's error trace.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ADataType>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
    // Adjacent string literals concatenate into the single diagnostic shown by the compiler.
    throw "Unsupported data type detected!\n"
          "The ADataType is not recognized.\n"
          "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
          "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
          "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
          "(BF8), "
          "int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
          "Please verify that your kernel instance uses a supported data type.";
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::DataType enum value representing the input data type.
|
||||
/// @details This function examines the Instance's ADataType to determine the data type
|
||||
/// used for the input tensor. The function supports various floating-point and integer
|
||||
/// types, including tuple types for mixed-precision operations.
|
||||
///
|
||||
/// Supported data types include:
|
||||
/// - FP16 (ck::half_t)
|
||||
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
|
||||
/// - BF16 (ck::bhalf_t)
|
||||
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
|
||||
/// - FP32 (float)
|
||||
/// - FP32_FP32 (ck::Tuple<float, float>)
|
||||
/// - FP64 (double)
|
||||
/// - FP8 (ck::f8_t)
|
||||
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
|
||||
/// - I8 (int8_t)
|
||||
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
|
||||
/// - U8 (uint8_t)
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Elementwise Operations
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
/// @tparam ElementwiseOp The unrecognized elementwise operation functor type; carried in the
///         signature so it appears in the compiler's error trace.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ElementwiseOp>
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
{
    // Adjacent string literals concatenate into the single diagnostic shown by the compiler.
    throw "Unsupported elementwise operation detected!\n"
          "The elementwise operation type is not recognized.\n"
          "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
          "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
          "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
          "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
          "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
          "UnaryConvert.\n"
          "Please verify that your kernel instance uses a supported elementwise operation.";
}
|
||||
|
||||
/// @brief Derives the elementwise operation from an operation functor type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A builder::ElementwiseOperation enum value corresponding to the operation.
/// @details Matches the operation's (unqualified) type name, obtained via
/// detail::elementwise_op_name, against known names using case-insensitive EXACT
/// string equality -- ordering of the branches therefore does not affect the result.
///
/// Supported operations include:
/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc.
/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc.
/// - Clamping operations: Clamp, AddClamp, etc.
/// - Combined operations: Add_Activation_Mul_Clamp, etc.
/// - Utility operations: PassThrough, UnaryConvert, etc.
///
/// TODO: Consider changing this to direct checks on the types, not strings.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    using enum builder::ElementwiseOperation;
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();

    if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
        return ADD_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
        return ADD_RELU_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
        return BILINEAR;
    // NOTE(review): "BiasNormalizeInInferClamp" maps to the same enum value as
    // "BiasBnormClamp" -- confirm there is no distinct enum value intended here.
    else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
        return CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
        return CONV_INVSCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
        return CONV_SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
        return CONV_SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
        return CONV_SCALE_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
        return SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
        return SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
        return PASS_THROUGH;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
        return SCALEADD_SCALEADD_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
        return DYNAMIC_UNARY_OP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
        return UNARY_COMBINED_OP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
        return ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
        return ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
        return ADD_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
        return ADD_ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
        return ADD_MUL_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
        return ADD_MUL2_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
        return UNARY_CONVERT;
    else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
        return LOGISTIC;
    else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
        return CLIPPED_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Swish"))
        return SWISH;
    else if constexpr(detail::case_insensitive_equal(name, "Elu"))
        return ELU;
    else if constexpr(detail::case_insensitive_equal(name, "Power"))
        return POWER;
    else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
        return LEAKY_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
        return UNARY_ABS;
    else if constexpr(detail::case_insensitive_equal(name, "Relu"))
        return RELU;
    else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
        return SOFT_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
        return SIGMOID;
    else if constexpr(detail::case_insensitive_equal(name, "TanH"))
        return TANH;
    else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
        return GELU;
    else if constexpr(detail::case_insensitive_equal(name, "Silu"))
        return SILU;
    else
    {
        report_unsupported_elementwise_op_error<ElementwiseOp>();
        return PASS_THROUGH; // Unreachable
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// GEMM Padding
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Derives the GEMM padding specification from a kernel instance type.
|
||||
/// @tparam Instance A device kernel instance type.
|
||||
/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration.
|
||||
/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits
|
||||
/// and converts it to the builder framework's GemmPadding enum. The padding specification
|
||||
/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes.
|
||||
///
|
||||
/// Supported padding configurations include:
|
||||
/// - DEFAULT: No padding
|
||||
/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding
|
||||
/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding
|
||||
/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding
|
||||
/// - MNKO_PADDING: All dimensions padded
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 3: PIPELINE CONFIGURATION HELPERS
|
||||
// ============================================================================
|
||||
|
||||
/// @brief Safely extracts the pipeline version from InstanceTraits.
|
||||
/// @tparam InstTraits The InstanceTraits type to extract pipeline version from.
|
||||
/// @return The pipeline version as a builder::PipelineVersion enum value.
|
||||
/// @details This helper function checks if the InstanceTraits has a kPipelineVersion
|
||||
/// field and extracts it if present. If not present, it returns a default value (V1).
|
||||
/// This is necessary because not all convolution types expose pipeline version information.
|
||||
template <typename InstTraits>
|
||||
constexpr builder::PipelineVersion get_pipeline_version()
|
||||
{
|
||||
if constexpr(requires { InstTraits::kPipelineVersion; })
|
||||
{
|
||||
return convert_pipeline_version<InstTraits::kPipelineVersion>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::PipelineVersion::V1;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Safely extracts the pipeline scheduler from InstanceTraits.
|
||||
/// @tparam InstTraits The InstanceTraits type to extract pipeline scheduler from.
|
||||
/// @return The pipeline scheduler as a builder::PipelineScheduler enum value.
|
||||
/// @details This helper function checks if the InstanceTraits has a kPipelineScheduler
|
||||
/// or kLoopScheduler field and extracts it if present. If neither is present, it returns
|
||||
/// a default value (DEFAULT). This is necessary because different convolution types may
|
||||
/// expose scheduler information through different field names.
|
||||
template <typename InstTraits>
|
||||
constexpr builder::PipelineScheduler get_pipeline_scheduler()
|
||||
{
|
||||
if constexpr(requires { InstTraits::kPipelineScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<InstTraits::kPipelineScheduler>();
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kLoopScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<InstTraits::kLoopScheduler>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::PipelineScheduler::DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,8 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
@@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -175,6 +180,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
LoopSched,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
@@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -179,6 +184,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
BComputeDataType_,
|
||||
DirectLoad>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
@@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -173,6 +178,9 @@ struct InstanceTraits<
|
||||
BComputeDataType_,
|
||||
LoopSched>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
@@ -108,7 +108,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility)
|
||||
# Tests convolution trait selection and configuration
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/ck/test_conv_traits.cpp
|
||||
conv/ck/unit_instance_to_conv_traits.cpp)
|
||||
conv/ck/unit_instance_to_conv_traits_features.cpp
|
||||
conv/ck/unit_instance_to_conv_traits_instances.cpp)
|
||||
|
||||
# Tests convolution problem description and parameter handling
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <concepts>
|
||||
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
@@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1);
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
@@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
ck::LoopScheduler::Default, // LoopSched
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
}
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
@@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default>; // LoopSched
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,800 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests for Individual Conversion Functions
|
||||
// ============================================================================
|
||||
//
|
||||
// PURPOSE:
|
||||
// --------
|
||||
// These tests verify individual conversion and extraction functions that
|
||||
// transform raw CK kernel parameters into semantic types. Each test
|
||||
// focuses on a single conversion function to ensure it correctly maps
|
||||
// CK types to builder enums and structures.
|
||||
//
|
||||
// TEST COVERAGE:
|
||||
// --------------
|
||||
// 1. Enum Conversions:
|
||||
// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion)
|
||||
// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler)
|
||||
//
|
||||
// 2. Elementwise Operations (14 operations):
|
||||
// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd
|
||||
// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd
|
||||
//
|
||||
// 3. Convolution Properties:
|
||||
// - Direction detection (Forward)
|
||||
// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0,
|
||||
// Filter3x3, OddC)
|
||||
//
|
||||
// 4. Layout Detection:
|
||||
// - 1D layouts (GNWC, NWGC, NGCW)
|
||||
// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX)
|
||||
// - 3D layouts (GNDHWC, NDHWGC, NGCDHW)
|
||||
//
|
||||
// 5. Data Type Detection:
|
||||
// - FP16, BF16, FP32, I8
|
||||
//
|
||||
// 6. Pipeline Configuration:
|
||||
// - Pipeline versions (V2, V3)
|
||||
// - Schedulers (Interwave)
|
||||
//
|
||||
// 7. GEMM Padding Variations (17 types):
|
||||
// - Default, MNK, M, N, K, MN, MK, NK
|
||||
// - O, MO, NO, KO, MNO, MKO, NKO, MNKO
|
||||
// ============================================================================
|
||||
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::ck_tile::builder::ConvDirection;
|
||||
using ::ck_tile::builder::DataType;
|
||||
using ::ck_tile::builder::ElementwiseOperation;
|
||||
using ::ck_tile::builder::GemmPadding;
|
||||
using ::ck_tile::builder::PipelineScheduler;
|
||||
using ::ck_tile::builder::PipelineVersion;
|
||||
using ::ck_tile::builder::TensorLayout;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// ============================================================================
|
||||
// Test Helper Templates
|
||||
// ============================================================================
|
||||
// These templates reduce boilerplate by providing sensible defaults for
|
||||
// template parameters that don't vary in most tests.
|
||||
// ============================================================================
|
||||
|
||||
namespace defaults {
|
||||
// Default values used across most tests
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kMPerBlock = 128;
|
||||
static constexpr int kNPerBlock = 128;
|
||||
static constexpr int kKPerBlock = 16;
|
||||
static constexpr int kAK1 = 8;
|
||||
static constexpr int kBK1 = 8;
|
||||
static constexpr int kMPerXDL = 32;
|
||||
static constexpr int kNPerXDL = 32;
|
||||
static constexpr int kMXdlPerWave = 4;
|
||||
static constexpr int kNXdlPerWave = 4;
|
||||
static constexpr int kABlockTransferSrcVectorDim = 2;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = 8;
|
||||
static constexpr int kABlockTransferDstScalarPerVector_AK1 = 8;
|
||||
static constexpr int kABlockLdsExtraM = 1;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = 2;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = 8;
|
||||
static constexpr int kBBlockTransferDstScalarPerVector_BK1 = 8;
|
||||
static constexpr int kBBlockLdsExtraN = 1;
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = 1;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = 1;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8;
|
||||
static constexpr bool kDirectLoad = false;
|
||||
|
||||
using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
|
||||
using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
|
||||
using DefaultABlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>;
|
||||
using DefaultBBlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
|
||||
using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
|
||||
using DefaultBBlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>;
|
||||
using DefaultCDEBlockTransferClusterLengths = ck::Sequence<1, 32, 1, 8>;
|
||||
} // namespace defaults
|
||||
|
||||
// DeviceInstanceForTests - V3 variant with sensible defaults
|
||||
template <int NDimSpatial = 2,
|
||||
typename ALayout = ck::tensor_layout::convolution::GNHWC,
|
||||
typename BLayout = ck::tensor_layout::convolution::GKYXC,
|
||||
typename ELayout = ck::tensor_layout::convolution::GNHWK,
|
||||
typename ADataType = ck::half_t,
|
||||
typename BDataType = ck::half_t,
|
||||
typename EDataType = ck::half_t,
|
||||
typename AccDataType = float,
|
||||
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched =
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer = ck::BlockGemmPipelineVersion::v1>
|
||||
using DeviceInstanceForTests_V3 =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
ADataType,
|
||||
ck::Tuple<>,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
defaults::kBlockSize,
|
||||
defaults::kMPerBlock,
|
||||
defaults::kNPerBlock,
|
||||
defaults::kKPerBlock,
|
||||
defaults::kAK1,
|
||||
defaults::kBK1,
|
||||
defaults::kMPerXDL,
|
||||
defaults::kNPerXDL,
|
||||
defaults::kMXdlPerWave,
|
||||
defaults::kNXdlPerWave,
|
||||
defaults::DefaultABlockTransferThreadClusterLengths,
|
||||
defaults::DefaultABlockTransferThreadClusterArrangeOrder,
|
||||
defaults::DefaultABlockTransferSrcAccessOrder,
|
||||
defaults::kABlockTransferSrcVectorDim,
|
||||
defaults::kABlockTransferSrcScalarPerVector,
|
||||
defaults::kABlockTransferDstScalarPerVector_AK1,
|
||||
defaults::kABlockLdsExtraM,
|
||||
defaults::DefaultBBlockTransferThreadClusterLengths,
|
||||
defaults::DefaultBBlockTransferThreadClusterArrangeOrder,
|
||||
defaults::DefaultBBlockTransferSrcAccessOrder,
|
||||
defaults::kBBlockTransferSrcVectorDim,
|
||||
defaults::kBBlockTransferSrcScalarPerVector,
|
||||
defaults::kBBlockTransferDstScalarPerVector_BK1,
|
||||
defaults::kBBlockLdsExtraN,
|
||||
defaults::kCShuffleMXdlPerWavePerShuffle,
|
||||
defaults::kCShuffleNXdlPerWavePerShuffle,
|
||||
defaults::DefaultCDEBlockTransferClusterLengths,
|
||||
defaults::kCDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ADataType,
|
||||
BDataType,
|
||||
defaults::kDirectLoad>;
|
||||
|
||||
// Test case helper for specialization testing
|
||||
template <ck::tensor_operation::device::ConvolutionForwardSpecialization Spec>
|
||||
using SpecializationTestInstance =
|
||||
DeviceInstanceForTests_V3<2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Spec>;
|
||||
|
||||
// Test case helper for layout testing (1D, 2D, 3D)
|
||||
template <int NDim, typename ALayout, typename BLayout, typename ELayout>
|
||||
using LayoutTestInstance = DeviceInstanceForTests_V3<NDim, ALayout, BLayout, ELayout>;
|
||||
|
||||
// Test case helper for data type testing
|
||||
template <typename DataType, typename AccDataType = float>
|
||||
using DataTypeTestInstance = DeviceInstanceForTests_V3<2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
AccDataType>;
|
||||
|
||||
// Test case helper for pipeline version testing
|
||||
template <ck::BlockGemmPipelineVersion PipelineVer>
|
||||
using PipelineVersionTestInstance = DeviceInstanceForTests_V3<
|
||||
2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
ck::tensor_operation::device::GemmSpecialization::Default,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
PipelineVer>;
|
||||
|
||||
// Test case helper for pipeline scheduler testing
|
||||
template <ck::BlockGemmPipelineScheduler Scheduler>
|
||||
using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3<
|
||||
2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
ck::tensor_operation::device::GemmSpecialization::Default,
|
||||
Scheduler>;
|
||||
|
||||
// Test case helper for GEMM padding testing
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
|
||||
using GemmPaddingTestInstance = DeviceInstanceForTests_V3<
|
||||
2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
GemmSpec>;
|
||||
|
||||
// ============================================================================
|
||||
// Test Enum Conversion Functions
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion)
|
||||
{
|
||||
using ck_tile::reflect::conv::convert_pipeline_version;
|
||||
using enum ::ck::BlockGemmPipelineVersion;
|
||||
using enum ::ck_tile::builder::PipelineVersion;
|
||||
EXPECT_EQ(convert_pipeline_version<v1>(), V1);
|
||||
EXPECT_EQ(convert_pipeline_version<v2>(), V2);
|
||||
EXPECT_EQ(convert_pipeline_version<v3>(), V3);
|
||||
EXPECT_EQ(convert_pipeline_version<v4>(), V4);
|
||||
EXPECT_EQ(convert_pipeline_version<v5>(), V5);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ConvertsPipelineVersion)
|
||||
{
|
||||
using ck_tile::reflect::conv::convert_pipeline_version;
|
||||
using enum ck::PipelineVersion;
|
||||
using enum PipelineVersion;
|
||||
EXPECT_EQ(convert_pipeline_version<v1>(), V1);
|
||||
EXPECT_EQ(convert_pipeline_version<v2>(), V2);
|
||||
EXPECT_EQ(convert_pipeline_version<v4>(), V4);
|
||||
EXPECT_EQ(convert_pipeline_version<weight_only>(), WEIGHT_ONLY);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler)
|
||||
{
|
||||
using ck_tile::reflect::conv::convert_pipeline_scheduler;
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum PipelineScheduler;
|
||||
EXPECT_EQ(convert_pipeline_scheduler<Intrawave>(), INTRAWAVE);
|
||||
EXPECT_EQ(convert_pipeline_scheduler<Interwave>(), INTERWAVE);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ConvertsLoopScheduler)
|
||||
{
|
||||
using ck_tile::reflect::conv::convert_pipeline_scheduler;
|
||||
using enum ck::LoopScheduler;
|
||||
using enum PipelineScheduler;
|
||||
EXPECT_EQ(convert_pipeline_scheduler<Default>(), DEFAULT);
|
||||
EXPECT_EQ(convert_pipeline_scheduler<Interwave>(), INTERWAVE);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Elementwise Operations
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsPassThroughOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::PassThrough>();
|
||||
EXPECT_EQ(op, PASS_THROUGH);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsScaleOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Scale>();
|
||||
EXPECT_EQ(op, SCALE);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsReluOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Relu>();
|
||||
EXPECT_EQ(op, RELU);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsGeluOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Gelu>();
|
||||
EXPECT_EQ(op, GELU);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsSigmoidOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Sigmoid>();
|
||||
EXPECT_EQ(op, SIGMOID);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsTanhOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::TanH>();
|
||||
EXPECT_EQ(op, TANH);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsScaleAddOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ScaleAdd>();
|
||||
EXPECT_EQ(op, SCALE_ADD);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsSiluOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Silu>();
|
||||
EXPECT_EQ(op, SILU);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsSwishOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Swish>();
|
||||
EXPECT_EQ(op, SWISH);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsEluOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Elu>();
|
||||
EXPECT_EQ(op, ELU);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsLeakyReluOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::LeakyRelu>();
|
||||
EXPECT_EQ(op, LEAKY_RELU);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::UnaryConvert>();
|
||||
EXPECT_EQ(op, UNARY_CONVERT);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsConvScaleOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScale>();
|
||||
EXPECT_EQ(op, CONV_SCALE);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
constexpr auto op =
|
||||
ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScaleAdd>();
|
||||
EXPECT_EQ(op, CONV_SCALE_ADD);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Convolution Direction Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, DetectsForwardDirection)
|
||||
{
|
||||
using DeviceInstance = DeviceInstanceForTests_V3<>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Convolution Specialization Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
|
||||
{
|
||||
using DeviceInstance = SpecializationTestInstance<
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
|
||||
{
|
||||
using DeviceInstance = SpecializationTestInstance<
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization)
|
||||
{
|
||||
using DeviceInstance = SpecializationTestInstance<
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.conv_specialization,
|
||||
ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization)
|
||||
{
|
||||
using DeviceInstance = SpecializationTestInstance<
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsOddCSpecialization)
|
||||
{
|
||||
using DeviceInstance = SpecializationTestInstance<
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 1D Convolution Layout Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsGnwcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<1,
|
||||
ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GNWK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 1);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNwgcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<1,
|
||||
ck::tensor_layout::convolution::NWGC,
|
||||
ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::NWGK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 1);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNgcwLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<1,
|
||||
ck::tensor_layout::convolution::NGCW,
|
||||
ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::NGKW>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 1);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 2D Convolution Layout Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsGnhwcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<2,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GNHWK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNhwgcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<2,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::NHWGK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<2,
|
||||
ck::tensor_layout::convolution::NGCHW,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::NGKHW>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<2,
|
||||
ck::tensor_layout::convolution::NGCHW,
|
||||
ck::tensor_layout::convolution::GKCYX,
|
||||
ck::tensor_layout::convolution::NGKHW>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test 3D Convolution Layout Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsGndhwcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<3,
|
||||
ck::tensor_layout::convolution::GNDHWC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::GNDHWK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 3);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNdhwgcLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<3,
|
||||
ck::tensor_layout::convolution::NDHWGC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::NDHWGK>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 3);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK));
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNgcdhwLayout)
|
||||
{
|
||||
using DeviceInstance = LayoutTestInstance<3,
|
||||
ck::tensor_layout::convolution::NGCDHW,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::NGKDHW>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.spatial_dim, 3);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Type Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFp16DataType)
|
||||
{
|
||||
using DeviceInstance = DataTypeTestInstance<ck::half_t>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsBf16DataType)
|
||||
{
|
||||
using DeviceInstance = DataTypeTestInstance<ck::bhalf_t>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.data_type, DataType::BF16);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFp32DataType)
|
||||
{
|
||||
using DeviceInstance = DataTypeTestInstance<float, float>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.data_type, DataType::FP32);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsI8DataType)
|
||||
{
|
||||
using DeviceInstance = DataTypeTestInstance<int8_t, int32_t>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.data_type, DataType::I8);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Pipeline Version Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsPipelineV2)
|
||||
{
|
||||
using DeviceInstance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v2>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V2);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsPipelineV3)
|
||||
{
|
||||
using DeviceInstance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v3>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V3);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsInterwaveScheduler)
|
||||
{
|
||||
using DeviceInstance = PipelineSchedulerTestInstance<ck::BlockGemmPipelineScheduler::Interwave>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTERWAVE);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test GEMM Padding Detection
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::Default>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMnkGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNK_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::M_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::N_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::K_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MN_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MK_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NKPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::NK_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::OPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::O_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::NO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::KO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MKO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NKOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::NKO_PADDING);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding)
|
||||
{
|
||||
using DeviceInstance =
|
||||
GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNKOPadding>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::MNKO_PADDING);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
@@ -0,0 +1,262 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests for Complete Device Instance Transformations
|
||||
// ============================================================================
|
||||
//
|
||||
// PURPOSE:
|
||||
// --------
|
||||
// These tests verify the complete instance_to_conv_traits transformation
|
||||
// for entire Device class templates. Each test validates that all traits
|
||||
// are correctly extracted from a specific Device class instantiation.
|
||||
//
|
||||
// TEST COVERAGE:
|
||||
// --------------
|
||||
// Complete transformation verification for each XDL Device class template:
|
||||
// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// 3. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
//
|
||||
// Each test verifies:
|
||||
// - Spatial dimension extraction
|
||||
// - Convolution direction
|
||||
// - Data type detection
|
||||
// - GEMM padding configuration
|
||||
// - Tile dimensions (M, N, K per block)
|
||||
// - Pipeline scheduler and version
|
||||
// ============================================================================
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::ck_tile::builder::ConvDirection;
|
||||
using ::ck_tile::builder::DataType;
|
||||
using ::ck_tile::builder::GemmPadding;
|
||||
using ::ck_tile::builder::PipelineScheduler;
|
||||
using ::ck_tile::builder::PipelineVersion;
|
||||
|
||||
// ============================================================================
|
||||
// Comprehensive Transformation Tests - Per Device Class Template
|
||||
// ============================================================================
|
||||
// These tests verify the complete InstanceTraits → ConvTraits transformation
|
||||
// for each forward convolution Device class template.
|
||||
// ============================================================================
|
||||
|
||||
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
ck::tensor_operation::device::GemmSpecialization::Default,
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
|
||||
using InstTraits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
|
||||
EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
|
||||
EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle)
{
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWC, // ALayout
            ck::tensor_layout::convolution::GKYXC, // BLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWK, // ELayout
            ck::half_t,                            // ADataType
            ck::half_t,                            // BDataType
            float,                                 // AccDataType
            ck::half_t,                            // CShuffleDataType
            ck::Tuple<>,                           // DsDataType
            ck::half_t,                            // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
            ck::tensor_operation::device::GemmSpecialization::Default,
            1,                         // NumGemmKPrefetchStage
            256,                       // BlockSize
            128,                       // MPerBlock
            128,                       // NPerBlock
            16,                        // KPerBlock
            8,                         // AK1
            8,                         // BK1
            32,                        // MPerXDL
            32,                        // NPerXDL
            4,                         // MXdlPerWave
            4,                         // NXdlPerWave
            ck::Sequence<4, 64, 1>,    // ABlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,     // ABlockTransferSrcAccessOrder
            2,                         // ABlockTransferSrcVectorDim
            8,                         // ABlockTransferSrcScalarPerVector
            8,                         // ABlockTransferDstScalarPerVector_AK1
            1,                         // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,    // BBlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,     // BBlockTransferSrcAccessOrder
            2,                         // BBlockTransferSrcVectorDim
            8,                         // BBlockTransferSrcScalarPerVector
            8,                         // BBlockTransferDstScalarPerVector_BK1
            1,                         // BBlockLdsExtraN
            1,                         // CShuffleMXdlPerWavePerShuffle
            1,                         // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
            8,                         // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                 // AComputeDataType
            ck::half_t,                 // BComputeDataType
            ck::LoopScheduler::Default, // LoopSched
            1>;                         // NumGroupsToMerge

    using InstTraits  = ck_tile::reflect::InstanceTraits<DeviceInstance>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Verify signature information
    EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(traits.data_type, DataType::FP16);
    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
    // Verify tile dimensions
    EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
    EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
    EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
    // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler)
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
|
||||
|
||||
TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor)
{
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWC, // ALayout
            ck::tensor_layout::convolution::GKYXC, // BLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWK, // ELayout
            ck::half_t,                            // ADataType
            ck::half_t,                            // BDataType
            float,                                 // AccDataType
            ck::half_t,                            // CShuffleDataType
            ck::Tuple<>,                           // DsDataType
            ck::half_t,                            // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
            ck::tensor_operation::device::GemmSpecialization::Default,
            1,                         // NumGemmKPrefetchStage
            256,                       // BlockSize
            128,                       // MPerBlock
            128,                       // NPerBlock
            16,                        // KPerBlock
            8,                         // AK1
            8,                         // BK1
            32,                        // MPerXDL
            32,                        // NPerXDL
            4,                         // MXdlPerWave
            4,                         // NXdlPerWave
            ck::Sequence<4, 64, 1>,    // ABlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,     // ABlockTransferSrcAccessOrder
            2,                         // ABlockTransferSrcVectorDim
            8,                         // ABlockTransferSrcScalarPerVector
            8,                         // ABlockTransferDstScalarPerVector_AK1
            1,                         // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,    // BBlockTransferThreadClusterLengths
            ck::Sequence<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,     // BBlockTransferSrcAccessOrder
            2,                         // BBlockTransferSrcVectorDim
            8,                         // BBlockTransferSrcScalarPerVector
            8,                         // BBlockTransferDstScalarPerVector_BK1
            1,                         // BBlockLdsExtraN
            1,                         // CShuffleMXdlPerWavePerShuffle
            1,                         // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
            8,                         // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                  // AComputeDataType
            ck::half_t,                  // BComputeDataType
            ck::LoopScheduler::Default>; // LoopSched

    using InstTraits  = ck_tile::reflect::InstanceTraits<DeviceInstance>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Verify signature information
    EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim);
    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(traits.data_type, DataType::FP16);
    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
    // Verify tile dimensions
    EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock);
    EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock);
    EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock);
    // Verify pipeline configuration
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
|
||||
|
||||
} // anonymous namespace
|
||||
Reference in New Issue
Block a user