mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_BUILDER] Convert convolution traits to a struct with factory functions (#3547)
* Factor helpers out of conv_traits.hpp * Create a non-templated conv_traits struct * Migrate to new instance-specific instance_to_conv_traits functions * Clean up reflection concepts * Clean up ConvTraits helpers * Update testing for convolution traits This is a lot of cleanup on tests to have verbose coverage of feature extraction, explicit tests for each supported device kernel, and simple, readable test code. * Address reviewer comments and resolve merge conflict
This commit is contained in:
@@ -7,43 +7,52 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_description.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <conv::HasConvTraits Instance>
|
||||
/// @brief Concept to check if an Instance type has conv traits
|
||||
template <typename Instance>
|
||||
concept HasConvTraits = requires {
|
||||
{ conv::instance_to_conv_traits<Instance>() };
|
||||
};
|
||||
|
||||
/// Factory function to create ConvDescription from a convolution instance type
|
||||
/// Instance The convolution instance type
|
||||
/// A ConvDescription object populated with the instance's configuration details
|
||||
///
|
||||
/// TODO: Fix ConvDescription to just use the ConvTraits directly.
|
||||
template <typename Instance>
|
||||
requires HasConvTraits<Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
const auto traits = conv::instance_to_conv_traits<Instance>();
|
||||
|
||||
return conv::ConvDescription(
|
||||
conv::ConvSignatureInfo{
|
||||
.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.input_layout = Traits::layout[0],
|
||||
.weight_layout = Traits::layout[1],
|
||||
.output_layout = Traits::layout[2],
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op,
|
||||
.spatial_dim = traits.spatial_dim,
|
||||
.direction = traits.direction,
|
||||
.input_layout = traits.layout[0],
|
||||
.weight_layout = traits.layout[1],
|
||||
.output_layout = traits.layout[2],
|
||||
.data_type = traits.data_type,
|
||||
.input_element_op = traits.input_element_op,
|
||||
.weight_element_op = traits.weight_element_op,
|
||||
.output_element_op = traits.output_element_op,
|
||||
},
|
||||
conv::GemmAlgorithmInfo{
|
||||
.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding,
|
||||
.thread_block_size = traits.thread_block_size,
|
||||
.tile_dims = traits.tile_dims,
|
||||
.warp_gemm = traits.warp_gemm,
|
||||
.a_tile_transfer = traits.a_tile_transfer,
|
||||
.b_tile_transfer = traits.b_tile_transfer,
|
||||
.c_tile_transfer = traits.c_tile_transfer,
|
||||
.pipeline_version = traits.pipeline_version,
|
||||
.pipeline_scheduler = traits.pipeline_scheduler,
|
||||
.conv_specialization = traits.conv_specialization,
|
||||
.padding = traits.gemm_padding,
|
||||
},
|
||||
[]() { return reflect::instance_string<Instance>(); });
|
||||
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
|
||||
@@ -1,664 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Runtime-accessible convolution kernel configuration data structure
|
||||
//
|
||||
// This file defines ConvTraits, a pure data structure that captures the complete
|
||||
// configuration of a convolution kernel in a domain-specific abstraction, without
|
||||
// requiring knowledge of the underlying kernel instance implementation details.
|
||||
//
|
||||
// ## Purpose and Design
|
||||
//
|
||||
// ConvTraits provides type erasure for convolution kernel configurations, allowing
|
||||
// for reflection of convolution kernel objects. The struct represents kernel
|
||||
// traits in terms of convolution-specific concepts for AMD GPUs rather than raw
|
||||
// template parameters.
|
||||
//
|
||||
// ## Architecture and Usage
|
||||
//
|
||||
// ConvTraits sits at the center of the reflection system:
|
||||
//
|
||||
// 1. **Population**: Values are created by `instance_to_conv_traits()` template
|
||||
// specializations that extract configuration from compile-time InstanceTraits
|
||||
//
|
||||
// 2. **Consumption**: Used by ConvDescription to provide human-readable descriptions
|
||||
// of kernel configurations for debugging, logging, and documentation
|
||||
//
|
||||
// ## Structure Organization
|
||||
//
|
||||
// The struct separates kernel configuration into two logical categories:
|
||||
//
|
||||
// - **Signature Information**: Defines what the kernel computes (direction, layouts,
|
||||
// data types, elementwise operations, specializations)
|
||||
//
|
||||
// - **Algorithm Information**: Defines how the kernel computes (thread block size,
|
||||
// tile dimensions, memory access patterns, pipeline configuration)
|
||||
//
|
||||
// ## Evolution and Extensibility
|
||||
//
|
||||
// ConvTraits is designed to evolve through composition (not inheritance):
|
||||
//
|
||||
// - Currently supports XDL forward convolution kernels
|
||||
// - Will extend to the other forward convolutions
|
||||
// - Will be extended to cover backward data and backward weight convolutions
|
||||
// - Will incorporate fusion operations and additional specializations
|
||||
// - Uses std::optional and std::variant for optional/variant fields
|
||||
// - Eventually will generalize to KernelTraits for GEMM, flash attention, etc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_types.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Forward convolution layout concept - checks for A/B/E layout types
|
||||
template <typename T>
|
||||
concept HasFwdConvLayouts = requires {
|
||||
typename T::ALayout;
|
||||
typename T::BLayout;
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
{
|
||||
T::kGemmSpecialization
|
||||
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
|
||||
};
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
concept HasElementwiseOps = requires {
|
||||
typename T::AElementwiseOperation;
|
||||
typename T::BElementwiseOperation;
|
||||
typename T::CDEElementwiseOperation;
|
||||
};
|
||||
|
||||
// Tile parameters concept - checks for tile dimension and transfer members
|
||||
template <typename T>
|
||||
concept HasTileParams = requires {
|
||||
{ T::kKPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kMPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kNPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kAK1 } -> std::convertible_to<int>;
|
||||
{ T::kBK1 } -> std::convertible_to<int>;
|
||||
T::kCThreadClusterLengths;
|
||||
};
|
||||
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
// Runtime data structure representing a convolution kernel's complete configuration
|
||||
//
|
||||
// This pure data struct (no template parameters, no static members) provides
|
||||
// type erasure for convolution kernel configurations. It can hold the configuration
|
||||
// from any convolution kernel instance, enabling runtime storage, comparison, and
|
||||
// manipulation of kernel properties.
|
||||
//
|
||||
// The struct is populated by `instance_to_conv_traits()` template specializations
|
||||
// that extract compile-time configuration from InstanceTraits and convert it to
|
||||
// this standardized runtime representation.
|
||||
//
|
||||
// Members are organized into two categories:
|
||||
// - **Signature Information**: Defines the computational interface (what to compute)
|
||||
// - **Algorithm Information**: Defines the implementation strategy (how to compute)
|
||||
//
|
||||
// Note: This struct will evolve to support additional convolution variants and
|
||||
// eventually generalize to other kernel types through composition.
|
||||
//
|
||||
// There is a lot we still need to do:
|
||||
//
|
||||
// TODO: Generalize type support for all tensors and accumulator.
|
||||
// TODO: Describe all tensors.
|
||||
// TODO: Include the full generalization of the signature from the input schema.
|
||||
// TODO: Include the full generalization of the algorithm from the input schema.
|
||||
struct ConvTraits
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
|
||||
/// across multiple wavefronts.
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
|
||||
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
|
||||
/// performance in certain scenarios.
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper metafunctions to derive signature information from Instance types
|
||||
|
||||
/// @brief Helper function to report unsupported convolution direction with a clear error message.
|
||||
template <typename Instance>
|
||||
[[noreturn]] consteval void report_unsupported_conv_direction_error()
|
||||
{
|
||||
throw "Unsupported convolution direction detected!\n"
|
||||
"The kernel instance does not have a recognized convolution specialization.\n"
|
||||
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
|
||||
"kConvBwdWeightSpecialization.\n"
|
||||
"Please verify that your kernel instance is properly configured.";
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
return builder::ConvDirection::FORWARD;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_direction_error<Instance>();
|
||||
return builder::ConvDirection::FORWARD; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper variable template: true when the A/B/E layout types match the
/// expected layout triple exactly.
template <typename A,
          typename B,
          typename E,
          typename ExpectedA,
          typename ExpectedB,
          typename ExpectedE>
inline constexpr bool layouts_are =
    std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Helper function to report unsupported layout combinations with a clear error message.
|
||||
/// @details This consteval function is designed to fail at compile time with a descriptive
|
||||
/// error message when an unsupported layout combination is encountered.
|
||||
template <typename A, typename B, typename E, int SpatialDim>
|
||||
[[noreturn]] consteval void report_unsupported_layout_error()
|
||||
{
|
||||
// This will produce a compile-time error with the exception message
|
||||
throw "Unsupported convolution layout combination detected!\n"
|
||||
"The combination of ALayout, BLayout, and ELayout template parameters\n"
|
||||
"is not recognized for the given spatial dimension.\n"
|
||||
"Please verify that your convolution instance uses a supported layout configuration.\n"
|
||||
"Check the conv_layout() function for the list of supported layout combinations.";
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
/// index 0 -> Input layout
|
||||
/// index 1 -> Weight layout
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
{
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
return layouts(NGCW, GKXC, NGKW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
|
||||
return layouts(NGCW, GKCX, NGKW);
|
||||
break;
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKCYX, NGKHW);
|
||||
break;
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKZYXC, NGKDHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
|
||||
// If we reach here, the layout combination is not supported
|
||||
// Call consteval function to trigger a compile-time error with a clear message
|
||||
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
|
||||
// This return is unreachable but needed to satisfy the compiler
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported data type with a clear error message.
|
||||
template <typename ADataType>
|
||||
[[noreturn]] consteval void report_unsupported_data_type_error()
|
||||
{
|
||||
throw "Unsupported data type detected!\n"
|
||||
"The ADataType is not recognized.\n"
|
||||
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
|
||||
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
|
||||
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
|
||||
"(BF8), "
|
||||
"int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
|
||||
"Please verify that your kernel instance uses a supported data type.";
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
|
||||
template <typename ElementwiseOp>
|
||||
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
|
||||
{
|
||||
throw "Unsupported elementwise operation detected!\n"
|
||||
"The elementwise operation type is not recognized.\n"
|
||||
"Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
|
||||
"BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
|
||||
"ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
|
||||
"UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
|
||||
"Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
|
||||
"UnaryConvert.\n"
|
||||
"Please verify that your kernel instance uses a supported elementwise operation.";
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
/// @tparam ElementwiseOp Elementwise operation functor type.
|
||||
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
|
||||
return ADD_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
|
||||
return ADD_RELU_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
|
||||
return BILINEAR;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
|
||||
return BIAS_BNORM_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
|
||||
return CONV_INVSCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
|
||||
return CONV_SCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
|
||||
return CONV_SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
|
||||
return CONV_SCALE_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
|
||||
return SCALE_ADD;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
|
||||
return DYNAMIC_UNARY_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
|
||||
return UNARY_COMBINED_OP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
|
||||
return ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
|
||||
return ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
|
||||
return ADD_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
|
||||
return ADD_ACTIVATION_MUL2_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
|
||||
return ADD_MUL_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
|
||||
return ADD_MUL2_ACTIVATION_MUL_CLAMP;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
|
||||
return UNARY_CONVERT;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
|
||||
return LOGISTIC;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
|
||||
return CLIPPED_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Swish"))
|
||||
return SWISH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Elu"))
|
||||
return ELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Power"))
|
||||
return POWER;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
|
||||
return LEAKY_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
|
||||
return UNARY_ABS;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Relu"))
|
||||
return RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
|
||||
return SOFT_RELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
|
||||
return SIGMOID;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "TanH"))
|
||||
return TANH;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
|
||||
return GELU;
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Silu"))
|
||||
return SILU;
|
||||
else
|
||||
{
|
||||
report_unsupported_elementwise_op_error<ElementwiseOp>();
|
||||
return PASS_THROUGH; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
/// @tparam Instance - A Device Kernel object type.
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
requires HasGemmSpec<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
switch(gemm_spec)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Primary template for extracting convolution traits.
|
||||
/// @details This struct is the main entry point for reflecting on a convolution
|
||||
/// kernel's properties. It is specialized to handle different kinds of input types.
|
||||
template <typename T>
|
||||
struct ConvTraits;
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
|
||||
/// @details This is the primary specialization used to extract a comprehensive
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <HasInstanceTraits Instance>
|
||||
requires IsXdlFwdConv<InstanceTraits<Instance>>
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
// --- Signature Information ---
|
||||
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
|
||||
static constexpr int spatial_dim = InstTraits::kSpatialDim;
|
||||
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
|
||||
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
|
||||
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
|
||||
static constexpr auto layout = conv_layout<Instance>();
|
||||
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
|
||||
static constexpr builder::DataType data_type = conv_data_type<Instance>();
|
||||
int spatial_dim;
|
||||
builder::ConvDirection direction;
|
||||
std::array<builder::TensorLayout, 3> layout; // [input, weight, output]
|
||||
builder::DataType data_type;
|
||||
|
||||
static constexpr builder::ElementwiseOperation input_element_op =
|
||||
elementwise_op<typename InstTraits::AElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation weight_element_op =
|
||||
elementwise_op<typename InstTraits::BElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation output_element_op =
|
||||
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
|
||||
builder::ElementwiseOperation input_element_op;
|
||||
builder::ElementwiseOperation weight_element_op;
|
||||
builder::ElementwiseOperation output_element_op;
|
||||
|
||||
/// @brief The GEMM specialization used by the kernel - padding
|
||||
static constexpr auto gemm_padding = gemm_spec<Instance>();
|
||||
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
|
||||
static constexpr auto conv_specialization = conv_spec<Instance>();
|
||||
builder::GemmPadding gemm_padding;
|
||||
builder::ConvSpecialization conv_specialization;
|
||||
|
||||
// --- Algorithm Information ---
|
||||
/// @brief The total number of threads in a thread block (workgroup).
|
||||
static constexpr int thread_block_size = InstTraits::kBlockSize;
|
||||
/// @brief The dimensions of the data tile processed by the thread block.
|
||||
static constexpr DataTileInfo tile_dims = {
|
||||
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
|
||||
int thread_block_size;
|
||||
DataTileInfo tile_dims;
|
||||
|
||||
/// @brief Configuration for the A-matrix (input) tile transfer.
|
||||
static constexpr InputTileTransferInfo a_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
|
||||
InputTileTransferInfo a_tile_transfer;
|
||||
InputTileTransferInfo b_tile_transfer;
|
||||
|
||||
/// @brief Configuration for the B-matrix (weights) tile transfer.
|
||||
static constexpr InputTileTransferInfo b_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
|
||||
WarpGemmParams warp_gemm;
|
||||
|
||||
/// @brief Parameters for the warp-level GEMM computation.
|
||||
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave};
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
|
||||
/// @brief Configuration for the C-matrix (output) tile transfer.
|
||||
static constexpr OutputTileTransferInfo c_tile_transfer = {
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
|
||||
/// @brief Helper to safely get the pipeline version.
|
||||
/// @details This is only available for some convolutions (e.g., forward).
|
||||
/// If not present in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_version()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineVersion; })
|
||||
{
|
||||
return convert_pipeline_version<T::kPipelineVersion>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineVersion::V1;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The block GEMM pipeline version used by the kernel.
|
||||
static constexpr auto pipeline_version = get_pipeline_version();
|
||||
|
||||
/// @brief Helper to safely get the pipeline scheduler.
|
||||
/// @details This is only available for some convolutions. If not present
|
||||
/// in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_scheduler()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kPipelineScheduler>();
|
||||
}
|
||||
else if constexpr(requires { T::kLoopScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kLoopScheduler>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineScheduler::DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The pipeline scheduler used by the kernel.
|
||||
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,739 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_types.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
/// @file conv_traits_helpers.hpp
|
||||
/// @brief Helper utilities for extracting convolution traits from kernel instances
|
||||
///
|
||||
/// This file provides compile-time reflection utilities to extract configuration
|
||||
/// information from CK convolution kernel instances and convert them to the builder
|
||||
/// framework's standardized representation.
|
||||
///
|
||||
/// ## Organization
|
||||
///
|
||||
/// The file is organized into the following sections:
|
||||
///
|
||||
/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums
|
||||
/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion)
|
||||
/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler)
|
||||
///
|
||||
/// 2. **Signature Derivation**: Functions to extract signature information from instances
|
||||
/// - Convolution direction (conv_direction)
|
||||
/// - Convolution specialization (conv_spec)
|
||||
/// - Tensor layouts (conv_layout)
|
||||
/// - Data types (conv_data_type)
|
||||
/// - Elementwise operations (elementwise_op)
|
||||
/// - GEMM padding (gemm_spec)
|
||||
///
|
||||
/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters
|
||||
/// - Pipeline version extraction (get_pipeline_version)
|
||||
/// - Pipeline scheduler extraction (get_pipeline_scheduler)
|
||||
///
|
||||
/// ## Error Handling Strategy
|
||||
///
|
||||
/// This file uses a specific error handling pattern for compile-time errors:
|
||||
/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't
|
||||
/// silently ignore errors. The thrown string becomes part of the compiler error message,
|
||||
/// providing clear context to developers.
|
||||
/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE,
|
||||
/// which would hide errors instead of reporting them clearly.
|
||||
///
|
||||
/// @example
|
||||
/// ```cpp
|
||||
/// using Instance =
|
||||
/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>;
|
||||
///
|
||||
/// // Extract convolution direction
|
||||
/// constexpr auto dir = conv_direction<Instance>();
|
||||
///
|
||||
/// // Extract data type
|
||||
/// constexpr auto dtype = conv_data_type<Instance>();
|
||||
///
|
||||
/// // Extract layout configuration
|
||||
/// constexpr auto layouts = conv_layout<Instance>();
|
||||
/// ```
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 1: ENUM CONVERSIONS
|
||||
// ============================================================================
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value.
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - v1 -> V1
|
||||
/// - v2 -> V2
|
||||
/// - v3 -> V3
|
||||
/// - v4 -> V4
|
||||
/// - v5 -> V5
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr builder::PipelineVersion convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value.
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - v1 -> V1
|
||||
/// - v2 -> V2
|
||||
/// - v4 -> V4
|
||||
/// - weight_only -> WEIGHT_ONLY
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr builder::PipelineVersion convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value.
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront
|
||||
/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value.
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads.
|
||||
///
|
||||
/// Supported mappings:
|
||||
/// - Default -> DEFAULT: Standard scheduling strategy
|
||||
/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr builder::PipelineScheduler convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 2: SIGNATURE DERIVATION FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Convolution Direction
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported convolution direction with a clear error message.
|
||||
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
|
||||
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
|
||||
template <typename Instance>
|
||||
[[noreturn]] consteval void report_unsupported_conv_direction_error()
|
||||
{
|
||||
throw "Unsupported convolution direction detected!\n"
|
||||
"The kernel instance does not have a recognized convolution specialization.\n"
|
||||
"Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
|
||||
"kConvBwdWeightSpecialization.\n"
|
||||
"Please verify that your kernel instance is properly configured.";
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
/// @details This function inspects the Instance's InstanceTraits to determine which
|
||||
/// convolution specialization field is present, and returns the corresponding direction.
|
||||
///
|
||||
/// The function checks for the presence of:
|
||||
/// - kConvForwardSpecialization -> FORWARD
|
||||
/// - kConvBwdDataSpecialization -> BACKWARD_DATA
|
||||
/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT
|
||||
///
|
||||
/// @note Compilation will fail with a clear error message if the instance does not
|
||||
/// have a recognized convolution specialization field.
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
return builder::ConvDirection::FORWARD;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_direction_error<Instance>();
|
||||
return builder::ConvDirection::FORWARD; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Convolution Specialization
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported convolution specialization with a clear error
/// message.
/// @tparam Instance The offending device kernel instance type; it appears in the compiler's
/// instantiation backtrace so the user can identify which instance failed.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename Instance>
[[noreturn]] consteval void report_unsupported_conv_spec_error()
{
    // A throw inside a consteval evaluation is ill-formed, so calling this function always
    // aborts compilation with a diagnostic that quotes the string below.
    throw "Unsupported convolution specialization detected!\n"
          "The kernel instance does not have a recognized convolution specialization field.\n"
          "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or "
          "kConvBwdWeightSpecialization.\n"
          "Please verify that your kernel instance is properly configured.";
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::ConvSpecialization enum value.
|
||||
/// @details This function extracts the specialization enum from the Instance's InstanceTraits
|
||||
/// and converts it to the corresponding builder framework enum.
|
||||
///
|
||||
/// For forward convolutions, supported specializations include:
|
||||
/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC
|
||||
///
|
||||
/// For backward data convolutions:
|
||||
/// - Default, Filter1x1Stride1Pad0
|
||||
///
|
||||
/// For backward weight convolutions:
|
||||
/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvSpecialization conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
report_unsupported_conv_spec_error<Instance>();
|
||||
return builder::ConvSpecialization::DEFAULT; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Tensor Layouts
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @tparam A The instance's ALayout (input) type, shown in the instantiation backtrace.
/// @tparam B The instance's BLayout (weight) type, shown in the instantiation backtrace.
/// @tparam E The instance's ELayout (output) type, shown in the instantiation backtrace.
/// @tparam SpatialDim The spatial dimension for which the combination was rejected.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
    // A throw inside a consteval evaluation is ill-formed, so calling this function always
    // aborts compilation with a diagnostic that quotes the string below.
    throw "Unsupported convolution layout combination detected!\n"
          "The combination of ALayout, BLayout, and ELayout template parameters\n"
          "is not recognized for the given spatial dimension.\n"
          "Please verify that your convolution instance uses a supported layout configuration.\n"
          "Check the conv_layout() function for the list of supported layout combinations.";
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
/// - [0] Input tensor layout
/// - [1] Weight tensor layout
/// - [2] Output tensor layout
/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
/// along with the spatial dimension to determine the appropriate layout configuration.
///
/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
/// The blocked CK layout types (e.g. G_NW_C, G_K_X_C) are reported with the same
/// builder enumerators as their packed counterparts (e.g. GNWC, GKXC).
///
/// @note Compilation will fail with a clear error message if the layout combination
/// is not supported for the given spatial dimension.
///
/// TODO: If we don't check for supported layouts, this function can be simplified.
template <typename Instance>
constexpr std::array<builder::TensorLayout, 3> conv_layout()
{
    using InstTraits = InstanceTraits<Instance>;
    using A = typename InstTraits::ALayout;
    using B = typename InstTraits::BLayout;
    using E = typename InstTraits::ELayout;
    namespace ctl = ck::tensor_layout::convolution;
    using enum builder::TensorLayout;

    // Helper to check if layouts match expected types
    constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
        return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
    };

    // Helper to construct layout array
    constexpr auto make_layouts = [](auto in, auto weight, auto out) {
        return std::array<builder::TensorLayout, 3>{in, weight, out};
    };

    constexpr int spatial_dim = InstTraits::kSpatialDim;

    if constexpr(spatial_dim == 1)
    {
        if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
            return make_layouts(GNWC, GKXC, GNWK);
        // Blocked variants map to the same builder enumerators as the packed layouts.
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
            return make_layouts(GNWC, GKXC, GNWK);
        else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
            return make_layouts(NWGC, GKXC, NWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
            return make_layouts(NGCW, GKXC, NGKW);
        else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
            return make_layouts(NGCW, GKCX, NGKW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNWC, GKXC, GNWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 2)
    {
        if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        // Blocked variants map to the same builder enumerators as the packed layouts.
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
            return make_layouts(GNHWC, GKYXC, GNHWK);
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        // NOTE(review): a KYXGC weight layout is reported as GKYXC here — verify this is
        // intentional (e.g. builder::TensorLayout has no KYXGC enumerator) and not a
        // copy/paste slip from the branch above.
        else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
            return make_layouts(NHWGC, GKYXC, NHWGK);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
            return make_layouts(NGCHW, GKYXC, NGKHW);
        else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
            return make_layouts(NGCHW, GKCYX, NGKHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
        }
    }
    else if constexpr(spatial_dim == 3)
    {
        if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        // Blocked variants map to the same builder enumerators as the packed layouts.
        else if constexpr(layouts_match
                              .template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
        else if constexpr(layouts_match
                              .template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
            return make_layouts(NDHWGC, GKZYXC, NDHWGK);
        else if constexpr(layouts_match
                              .template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKZYXC, NGKDHW);
        else if constexpr(layouts_match
                              .template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
            return make_layouts(NGCDHW, GKCZYX, NGKDHW);
        else
        {
            report_unsupported_layout_error<A, B, E, spatial_dim>();
            return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
        }
    }
    else
    {
        // Spatial dimensions other than 1/2/3 are always rejected.
        report_unsupported_layout_error<A, B, E, spatial_dim>();
        return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Data Types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported data type with a clear error message.
/// @tparam ADataType The offending input data type, shown in the instantiation backtrace.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ADataType>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
    // A throw inside a consteval evaluation is ill-formed, so calling this function always
    // aborts compilation with a diagnostic that quotes the string below.
    throw "Unsupported data type detected!\n"
          "The ADataType is not recognized.\n"
          "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
          "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
          "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
          "(BF8), "
          "int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
          "Please verify that your kernel instance uses a supported data type.";
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel Instance type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A builder::DataType enum value representing the input data type.
|
||||
/// @details This function examines the Instance's ADataType to determine the data type
|
||||
/// used for the input tensor. The function supports various floating-point and integer
|
||||
/// types, including tuple types for mixed-precision operations.
|
||||
///
|
||||
/// Supported data types include:
|
||||
/// - FP16 (ck::half_t)
|
||||
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
|
||||
/// - BF16 (ck::bhalf_t)
|
||||
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
|
||||
/// - FP32 (float)
|
||||
/// - FP32_FP32 (ck::Tuple<float, float>)
|
||||
/// - FP64 (double)
|
||||
/// - FP8 (ck::f8_t)
|
||||
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
|
||||
/// - I8 (int8_t)
|
||||
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
|
||||
/// - U8 (uint8_t)
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
|
||||
return FP16_FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
|
||||
return BF16_BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
|
||||
return FP32_FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, double>)
|
||||
return FP64;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
|
||||
return BF8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
|
||||
return I8_I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
report_unsupported_data_type_error<ADataType>();
|
||||
return FP32; // Unreachable
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Elementwise Operations
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Helper function to report unsupported elementwise operation with a clear error message.
/// @tparam ElementwiseOp The offending operation functor type, shown in the instantiation
/// backtrace.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ElementwiseOp>
[[noreturn]] consteval void report_unsupported_elementwise_op_error()
{
    // A throw inside a consteval evaluation is ill-formed, so calling this function always
    // aborts compilation with a diagnostic that quotes the string below.
    throw "Unsupported elementwise operation detected!\n"
          "The elementwise operation type is not recognized.\n"
          "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
          "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
          "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
          "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
          "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
          "UnaryConvert.\n"
          "Please verify that your kernel instance uses a supported elementwise operation.";
}
|
||||
|
||||
/// @brief Derives the elementwise operation from an operation functor type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A builder::ElementwiseOperation enum value corresponding to the operation.
/// @details This function uses the operation's type name (obtained via
/// detail::elementwise_op_name) to determine which elementwise operation is being used.
/// The comparison is case-insensitive.
///
/// Supported operations include:
/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc.
/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc.
/// - Clamping operations: Clamp, AddClamp, etc.
/// - Combined operations: Add_Activation_Mul_Clamp, etc.
/// - Utility operations: PassThrough, UnaryConvert, etc.
///
/// @note Compilation fails with a descriptive message for unrecognized operation names.
///
/// TODO: Consider changing this to direct checks on the types, not strings.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    using enum builder::ElementwiseOperation;
    // Compile-time string of the functor's (unqualified) type name; all matching below
    // is full-name, case-insensitive equality against this value.
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();

    if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
        return ADD_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
        return ADD_RELU_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
        return BILINEAR;
    // NOTE(review): BiasNormalizeInInferClamp maps to the same enumerator as
    // BiasBnormClamp — presumably intentional aliasing (no dedicated enumerator);
    // confirm against builder::ElementwiseOperation.
    else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
        return BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
        return CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
        return CONV_INVSCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
        return CONV_SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
        return CONV_SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
        return CONV_SCALE_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
        return SCALE;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
        return SCALE_ADD;
    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
        return PASS_THROUGH;
    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
        return SCALEADD_SCALEADD_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
        return DYNAMIC_UNARY_OP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
        return UNARY_COMBINED_OP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
        return ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
        return ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
        return ADD_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
        return ADD_ACTIVATION_MUL2_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
        return ADD_MUL_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
        return ADD_MUL2_ACTIVATION_MUL_CLAMP;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
        return UNARY_CONVERT;
    else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
        return LOGISTIC;
    else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
        return CLIPPED_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Swish"))
        return SWISH;
    else if constexpr(detail::case_insensitive_equal(name, "Elu"))
        return ELU;
    else if constexpr(detail::case_insensitive_equal(name, "Power"))
        return POWER;
    else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
        return LEAKY_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
        return UNARY_ABS;
    else if constexpr(detail::case_insensitive_equal(name, "Relu"))
        return RELU;
    else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
        return SOFT_RELU;
    else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
        return SIGMOID;
    else if constexpr(detail::case_insensitive_equal(name, "TanH"))
        return TANH;
    else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
        return GELU;
    else if constexpr(detail::case_insensitive_equal(name, "Silu"))
        return SILU;
    else
    {
        report_unsupported_elementwise_op_error<ElementwiseOp>();
        return PASS_THROUGH; // Unreachable
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// GEMM Padding
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// @brief Derives the GEMM padding specification from a kernel instance type.
|
||||
/// @tparam Instance A device kernel instance type.
|
||||
/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration.
|
||||
/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits
|
||||
/// and converts it to the builder framework's GemmPadding enum. The padding specification
|
||||
/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes.
|
||||
///
|
||||
/// Supported padding configurations include:
|
||||
/// - DEFAULT: No padding
|
||||
/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding
|
||||
/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding
|
||||
/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding
|
||||
/// - MNKO_PADDING: All dimensions padded
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 3: PIPELINE CONFIGURATION HELPERS
|
||||
// ============================================================================
|
||||
|
||||
/// @brief Safely extracts the pipeline version from InstanceTraits.
|
||||
/// @tparam InstTraits The InstanceTraits type to extract pipeline version from.
|
||||
/// @return The pipeline version as a builder::PipelineVersion enum value.
|
||||
/// @details This helper function checks if the InstanceTraits has a kPipelineVersion
|
||||
/// field and extracts it if present. If not present, it returns a default value (V1).
|
||||
/// This is necessary because not all convolution types expose pipeline version information.
|
||||
template <typename InstTraits>
|
||||
constexpr builder::PipelineVersion get_pipeline_version()
|
||||
{
|
||||
if constexpr(requires { InstTraits::kPipelineVersion; })
|
||||
{
|
||||
return convert_pipeline_version<InstTraits::kPipelineVersion>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::PipelineVersion::V1;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Safely extracts the pipeline scheduler from InstanceTraits.
|
||||
/// @tparam InstTraits The InstanceTraits type to extract pipeline scheduler from.
|
||||
/// @return The pipeline scheduler as a builder::PipelineScheduler enum value.
|
||||
/// @details This helper function checks if the InstanceTraits has a kPipelineScheduler
|
||||
/// or kLoopScheduler field and extracts it if present. If neither is present, it returns
|
||||
/// a default value (DEFAULT). This is necessary because different convolution types may
|
||||
/// expose scheduler information through different field names.
|
||||
template <typename InstTraits>
|
||||
constexpr builder::PipelineScheduler get_pipeline_scheduler()
|
||||
{
|
||||
if constexpr(requires { InstTraits::kPipelineScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<InstTraits::kPipelineScheduler>();
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kLoopScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<InstTraits::kLoopScheduler>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::PipelineScheduler::DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,8 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
@@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -175,6 +180,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
LoopSched,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
@@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -179,6 +184,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
BComputeDataType_,
|
||||
DirectLoad>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
@@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
@@ -173,6 +178,9 @@ struct InstanceTraits<
|
||||
BComputeDataType_,
|
||||
LoopSched>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user