[CK_BUILDER] Convolution traits. (#3152)

Added:

1. Convolution traits & unit tests
2. Update builder enumerators to have representation of Convolution Kernels properties.
3. Unified builder pipeline version & scheduler enumerators
This commit is contained in:
Adam Osewski
2025-11-05 17:53:06 +01:00
committed by GitHub
parent 3b076b0b74
commit b8527a9236
20 changed files with 1165 additions and 81 deletions

View File

@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
// Concept for parameter that describe block GEMM problem.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept for parameters that describe a gridwise WMMA GEMM problem.
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept for vectorized data transfer for convolution input tensors.
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
@@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires {
template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
} // namespace ck_tile::builder

View File

@@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm()
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
}
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
if constexpr(BG.pipeline_version == PipelineVersion::V1)
{
version = ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
{
version = ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
{
version = ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
{
version = ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
{
version = ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
{
return ck::LoopScheduler::Default;
}
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
{
return ck::LoopScheduler::Interwave;
}
else
{
static_assert(false, "Unknown LoopScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
}
@@ -460,29 +460,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
if constexpr(pipeline_version == PipelineVersion::V1)
{
return ck::PipelineVersion::v1;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
else if constexpr(pipeline_version == PipelineVersion::V2)
{
return ck::PipelineVersion::v2;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
else if constexpr(pipeline_version == PipelineVersion::V3)
{
static_assert(false, "V3 is used only for stream-K.");
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
else if constexpr(pipeline_version == PipelineVersion::V4)
{
return ck::PipelineVersion::v4;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
{
return ck::PipelineVersion::weight_only;
}
else
{
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
if constexpr(version == PipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
else if constexpr(version == PipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
else if constexpr(version == PipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
else if constexpr(version == PipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
else if constexpr(version == PipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}

View File

@@ -0,0 +1,719 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/conv_factory.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
namespace ck_tile::reflect::conv {
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v3)
return V3;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == v5)
return V5;
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
/// @details This function maps CK's general pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. Note that this overload
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
/// variant, including support for specialized weight-only pipelines.
template <ck::PipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::PipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == weight_only)
return WEIGHT_ONLY;
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
/// builder framework's standardized scheduler enum. The scheduler determines how work
/// is distributed and synchronized within and across wavefronts during pipeline execution.
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
/// across multiple wavefronts.
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::BlockGemmPipelineScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Intrawave)
return INTRAWAVE;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
/// performance in certain scenarios.
template <ck::LoopScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::LoopScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Default)
return DEFAULT;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
/// @brief Helper structures for organizing trait data with domain-specific naming
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types
/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
{
return builder::ConvDirection::FORWARD;
}
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
{
return builder::ConvDirection::BACKWARD_DATA;
}
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
{
return builder::ConvDirection::BACKWARD_WEIGHT;
}
else
{
return builder::ConvDirection::FORWARD; // Default fallback
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
/// `builder::ConvBwdWeightSpecialization` enum value.
template <typename Instance>
constexpr auto conv_spec()
{
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
if constexpr(InstTraits::kConvForwardSpecialization == Default)
{
return builder::ConvFwdSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
{
return builder::ConvFwdSpecialization::FILTER_3x3;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
{
return builder::ConvBwdDataSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
}
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
{
return builder::ConvBwdWeightSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
{
return builder::ConvBwdWeightSpecialization::ODD_C;
}
}
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
template <typename Instance>
constexpr auto conv_layout()
{
using InstTraits = InstanceTraits<Instance>;
using ALayout = typename InstTraits::ALayout;
using BLayout = typename InstTraits::BLayout;
using ELayout = typename InstTraits::ELayout;
namespace ctc = ck::tensor_layout::convolution;
if constexpr(InstTraits::kSpatialDim == 1)
{
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
std::is_same_v<ELayout, ctc::GNWK>)
{
return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
{
return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
}
}
else if constexpr(InstTraits::kSpatialDim == 2)
{
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::GNHWK>)
{
return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NHWGK>)
{
return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKCYX> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
}
}
else if constexpr(InstTraits::kSpatialDim == 3)
{
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::GNDHWK>)
{
return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NDHWGK>)
{
return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKCZYX> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
}
}
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
template <typename Instance>
constexpr builder::DataType conv_data_type()
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
{
return builder::DataType::FP16;
}
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
{
return builder::DataType::BF16;
}
else if constexpr(std::is_same_v<ADataType, float>)
{
return builder::DataType::FP32;
}
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
{
return builder::DataType::FP8;
}
else if constexpr(std::is_same_v<ADataType, int8_t>)
{
return builder::DataType::I8;
}
else if constexpr(std::is_same_v<ADataType, uint8_t>)
{
return builder::DataType::U8;
}
else
{
// Default fallback
return builder::DataType::FP32;
}
}
/// @brief Derives the elementwise operation from op type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "Bias"))
{
return builder::ElementwiseOperation::BIAS;
}
else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
{
return builder::ElementwiseOperation::BIAS_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
{
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
{
return builder::ElementwiseOperation::BILINEAR;
}
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
{
return builder::ElementwiseOperation::CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
{
return builder::ElementwiseOperation::SCALE;
}
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
{
return builder::ElementwiseOperation::PASS_THROUGH;
}
}
/// @brief Derives a gemm padding from a kernel instance type.
/// @tparam Instance - A Device Kernel object type.
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::GemmPadding;
using enum ck::tensor_operation::device::GemmSpecialization;
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
if constexpr(gemm_spec == Default)
{
return DEFAULT;
}
else if constexpr(gemm_spec == MPadding)
{
return M_PADDING;
}
else if constexpr(gemm_spec == NPadding)
{
return N_PADDING;
}
else if constexpr(gemm_spec == KPadding)
{
return K_PADDING;
}
else if constexpr(gemm_spec == MNPadding)
{
return MN_PADDING;
}
else if constexpr(gemm_spec == MKPadding)
{
return MK_PADDING;
}
else if constexpr(gemm_spec == NKPadding)
{
return NK_PADDING;
}
else if constexpr(gemm_spec == MNKPadding)
{
return MNK_PADDING;
}
else if constexpr(gemm_spec == OPadding)
{
return O_PADDING;
}
else if constexpr(gemm_spec == MOPadding)
{
return MO_PADDING;
}
else if constexpr(gemm_spec == NOPadding)
{
return NO_PADDING;
}
else if constexpr(gemm_spec == KOPadding)
{
return KO_PADDING;
}
else if constexpr(gemm_spec == MNOPadding)
{
return MNO_PADDING;
}
else if constexpr(gemm_spec == MKOPadding)
{
return MKO_PADDING;
}
else if constexpr(gemm_spec == NKOPadding)
{
return NKO_PADDING;
}
else if constexpr(gemm_spec == MNKOPadding)
{
return MNKO_PADDING;
}
}
/// @brief Primary template for extracting convolution traits.
/// @details This struct is the main entry point for reflecting on a convolution
/// kernel's properties. It is specialized to handle different kinds of input types.
template <typename T>
struct ConvTraits;
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <typename Instance>
requires requires { typename InstanceTraits<Instance>; }
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;
// --- Signature Information ---
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
static constexpr int spatial_dim = InstTraits::kSpatialDim;
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
static constexpr auto layout = conv_layout<Instance>();
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
static constexpr builder::DataType data_type = conv_data_type<Instance>();
static constexpr builder::ElementwiseOperation input_element_op =
elementwise_op<typename InstTraits::AElementwiseOperation>();
static constexpr builder::ElementwiseOperation weight_element_op =
elementwise_op<typename InstTraits::BElementwiseOperation>();
static constexpr builder::ElementwiseOperation output_element_op =
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
/// @brief The GEMM specialization used by the kernel - padding
static constexpr auto gemm_padding = gemm_spec<Instance>();
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
static constexpr auto conv_specialization = conv_spec<Instance>();
// --- Algorithm Information ---
/// @brief The total number of threads in a thread block (workgroup).
static constexpr int thread_block_size = InstTraits::kBlockSize;
/// @brief The dimensions of the data tile processed by the thread block.
static constexpr DataTileInfo tile_dims = {
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
/// @brief Configuration for the A-matrix (input) tile transfer.
static constexpr InputTileTransferInfo a_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
/// @brief Configuration for the B-matrix (weights) tile transfer.
static constexpr InputTileTransferInfo b_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
/// @brief Parameters for the warp-level GEMM computation.
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
/// @brief Configuration for the C-matrix (output) tile transfer.
static constexpr OutputTileTransferInfo c_tile_transfer = {
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
/// @brief Helper to safely get the pipeline version.
/// @details This is only available for some convolutions (e.g., forward).
/// If not present in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_version()
{
if constexpr(requires { T::kPipelineVersion; })
{
return convert_pipeline_version<T::kPipelineVersion>();
}
else
{
// Return a default or indicate not available
return builder::PipelineVersion::V1;
}
}
/// @brief The block GEMM pipeline version used by the kernel.
static constexpr auto pipeline_version = get_pipeline_version();
/// @brief Helper to safely get the pipeline scheduler.
/// @details This is only available for some convolutions. If not present
/// in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_scheduler()
{
if constexpr(requires { T::kPipelineScheduler; })
{
return convert_pipeline_scheduler<T::kPipelineScheduler>();
}
else if constexpr(requires { T::kLoopScheduler; })
{
return convert_pipeline_scheduler<T::kLoopScheduler>();
}
else
{
// Return a default or indicate not available
return builder::PipelineScheduler::DEFAULT;
}
}
/// @brief The pipeline scheduler used by the kernel.
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder's factory, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
using Instance = typename Factory::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;
// Forward all members from Instance-based traits
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
static constexpr auto layout = InstanceConvTraits::layout;
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
static constexpr builder::ElementwiseOperation input_element_op =
InstanceConvTraits::input_element_op;
static constexpr builder::ElementwiseOperation weight_element_op =
InstanceConvTraits::weight_element_op;
static constexpr builder::ElementwiseOperation output_element_op =
InstanceConvTraits::output_element_op;
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -14,18 +14,9 @@
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
#include <concepts>
namespace ck_tile::reflect {

View File

@@ -15,6 +15,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -14,6 +14,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -9,9 +9,11 @@
#include <array>
#include <string>
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <climits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
@@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name()
}
}
/// @brief Makes a case insensitive comparison of two string views.
/// @param a First string view
/// @param b Second string view
/// @return Whether two string views a equal case insensitive
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
{
if(a.size() != b.size())
return false;
for(size_t i = 0; i < a.size(); ++i)
{
char c1 = a[i];
char c2 = b[i];
// Convert to lowercase for comparison
if(c1 >= 'A' && c1 <= 'Z')
c1 += 32;
if(c2 >= 'A' && c2 <= 'Z')
c2 += 32;
if(c1 != c2)
return false;
}
return true;
}
} // namespace ck_tile::reflect::detail

View File

@@ -128,29 +128,14 @@ enum class ElementwiseOperation
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
// Enums for pipeline versions & schedulers
enum class PipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
enum struct BlockGemmPipelineScheduler
{
INTRAWAVE,
INTERWAVE,
};
// Enums for the gridwise GEMM pipeline versions.
enum class GridwiseGemmPipelineVersion
{
V1,
V2,
V3, // Only used in stream-K implementation
V4,
V5,
WEIGHT_ONLY
};
@@ -186,9 +171,47 @@ enum class ConvFwdSpecialization
FILTER_3x3
};
enum class LoopScheduler
// Enums for the backward data convolution specialization.
enum class ConvBwdDataSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
};
// Enums for the backward weight convolution specialization.
enum class ConvBwdWeightSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
FILTER_1X1_PAD0,
ODD_C,
};
// Enums for the Gemm padding.
enum class GemmPadding
{
DEFAULT,
M_PADDING,
N_PADDING,
K_PADDING,
MN_PADDING,
MK_PADDING,
NK_PADDING,
MNK_PADDING,
O_PADDING,
MO_PADDING,
NO_PADDING,
KO_PADDING,
MNO_PADDING,
MKO_PADDING,
NKO_PADDING,
MNKO_PADDING,
};
enum class PipelineScheduler
{
DEFAULT,
INTRAWAVE,
INTERWAVE
};