[CK_BUILDER] Convolution traits. (#3152)

Added:

1. Convolution traits & unit tests
2. Updated builder enumerators to represent convolution kernel properties.
3. Unified builder pipeline version & scheduler enumerators
This commit is contained in:
Adam Osewski
2025-11-05 17:53:06 +01:00
committed by GitHub
parent 3b076b0b74
commit b8527a9236
20 changed files with 1165 additions and 81 deletions

View File

@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
// Concept for parameter that describe block GEMM problem.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept for parameters that describe a gridwise WMMA GEMM problem.
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept for vectorized data transfer for convolution input tensors.
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
@@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires {
template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
} // namespace ck_tile::builder

View File

@@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm()
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
}
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
if constexpr(BG.pipeline_version == PipelineVersion::V1)
{
version = ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
{
version = ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
{
version = ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
{
version = ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
{
version = ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
{
return ck::LoopScheduler::Default;
}
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
{
return ck::LoopScheduler::Interwave;
}
else
{
static_assert(false, "Unknown LoopScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
}
@@ -460,29 +460,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
if constexpr(pipeline_version == PipelineVersion::V1)
{
return ck::PipelineVersion::v1;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
else if constexpr(pipeline_version == PipelineVersion::V2)
{
return ck::PipelineVersion::v2;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
else if constexpr(pipeline_version == PipelineVersion::V3)
{
static_assert(false, "V3 is used only for stream-K.");
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
else if constexpr(pipeline_version == PipelineVersion::V4)
{
return ck::PipelineVersion::v4;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
{
return ck::PipelineVersion::weight_only;
}
else
{
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
if constexpr(version == PipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
else if constexpr(version == PipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
else if constexpr(version == PipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
else if constexpr(version == PipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
else if constexpr(version == PipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}

View File

@@ -0,0 +1,719 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/conv_factory.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
namespace ck_tile::reflect::conv {
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v3)
return V3;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == v5)
return V5;
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
/// @details This function maps CK's general pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. Note that this overload
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
/// variant, including support for specialized weight-only pipelines.
template <ck::PipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::PipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == weight_only)
return WEIGHT_ONLY;
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
/// builder framework's standardized scheduler enum. The scheduler determines how work
/// is distributed and synchronized within and across wavefronts during pipeline execution.
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
/// across multiple wavefronts.
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::BlockGemmPipelineScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Intrawave)
return INTRAWAVE;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
/// performance in certain scenarios.
template <ck::LoopScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::LoopScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Default)
return DEFAULT;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
// Helper structures for organizing trait data with domain-specific naming.

/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
    int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
    int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
    int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};

/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
    int k0;     ///< The outer dimension of K, where K = k0 * k1.
    int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
    int k1;     ///< The inner dimension of K, often corresponding to the vector load size from global
                ///< memory.
};

/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
    int k1; ///< The inner K dimension size, often matching the vectorization width.
    std::array<int, 3>
        thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
                             ///< many threads are arranged on each axis.
    std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
                                             ///< input tensor dimensions.
    std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
                                         ///< dimension to read first).
    int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
                        ///< (the contiguous dimension).
    int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
                               ///< elements accessed per thread per instruction.
    int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
                                  ///< dimension.
    bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
                      ///< conflicts.
};

/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
    InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
    InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};

/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
    int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
    int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
    int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
                ///< wavefront (MXdlPerWave).
    int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
                ///< wavefront (NXdlPerWave).
};

/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
    int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
                             ///< per shuffle iteration.
    int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
                             ///< per shuffle iteration.
};

/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
    WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
    // Cluster axes, in order: m_block, m_wave_per_xdl, n_block, n_wave_per_xdl.
    std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
                                            ///< data into the output tensor.
    int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
                           ///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types.

/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
/// @details Detection is based on which specialization constant the instance's
/// `InstanceTraits` expose (forward / backward-data / backward-weight).
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
    using InstTraits = InstanceTraits<Instance>;
    if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
    {
        return builder::ConvDirection::FORWARD;
    }
    else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
    {
        return builder::ConvDirection::BACKWARD_DATA;
    }
    else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
    {
        return builder::ConvDirection::BACKWARD_WEIGHT;
    }
    else
    {
        // NOTE(review): silently defaulting to FORWARD can mask an InstanceTraits
        // mismatch — confirm this fallback is intentional rather than a
        // static_assert candidate.
        return builder::ConvDirection::FORWARD; // Default fallback
    }
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
/// `builder::ConvBwdWeightSpecialization` enum value.
/// @details The return type depends on which specialization constant the traits expose,
/// hence the deduced `auto` return. NOTE(review): if a specialization value is not
/// matched by any branch, the function deduces a `void` return — consider terminal
/// static_asserts; confirm all enumerators are covered.
template <typename Instance>
constexpr auto conv_spec()
{
    using InstTraits = InstanceTraits<Instance>;
    if constexpr(requires { InstTraits::kConvForwardSpecialization; })
    {
        using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
        if constexpr(InstTraits::kConvForwardSpecialization == Default)
        {
            return builder::ConvFwdSpecialization::DEFAULT;
        }
        else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
        {
            return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
        }
        else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
        {
            return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
        }
        else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
        {
            return builder::ConvFwdSpecialization::FILTER_3x3;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
    {
        using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
        if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
        {
            return builder::ConvBwdDataSpecialization::DEFAULT;
        }
        else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
        {
            return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
        }
    }
    else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
    {
        using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
        if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
        {
            return builder::ConvBwdWeightSpecialization::DEFAULT;
        }
        else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
        {
            return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
        }
        else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
        {
            return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
        }
        else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
        {
            return builder::ConvBwdWeightSpecialization::ODD_C;
        }
    }
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
/// @details Dispatches on `kSpatialDim` (1, 2, or 3) and then matches the A (input),
/// B (weight), and E (output) layout types against the known CK convolution layouts.
/// NOTE(review): an unmatched layout combination (or spatial dim) makes the function
/// deduce a `void` return, which surfaces only as a downstream error — confirm the
/// supported set is exhaustive for the instances being reflected.
template <typename Instance>
constexpr auto conv_layout()
{
    using InstTraits = InstanceTraits<Instance>;
    using ALayout    = typename InstTraits::ALayout;
    using BLayout    = typename InstTraits::BLayout;
    using ELayout    = typename InstTraits::ELayout;
    namespace ctc    = ck::tensor_layout::convolution;
    if constexpr(InstTraits::kSpatialDim == 1)
    {
        if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
                     std::is_same_v<ELayout, ctc::GNWK>)
        {
            return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
                          std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
        {
            return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                          std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
        {
            return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
                          std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
        {
            return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
        }
    }
    else if constexpr(InstTraits::kSpatialDim == 2)
    {
        if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
                     std::is_same_v<ELayout, ctc::GNHWK>)
        {
            return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
                          std::is_same_v<BLayout, ctc::GKYXC> &&
                          std::is_same_v<ELayout, ctc::NHWGK>)
        {
            return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                          std::is_same_v<BLayout, ctc::GKYXC> &&
                          std::is_same_v<ELayout, ctc::NGKHW>)
        {
            return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
                          std::is_same_v<BLayout, ctc::GKCYX> &&
                          std::is_same_v<ELayout, ctc::NGKHW>)
        {
            return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
        }
    }
    else if constexpr(InstTraits::kSpatialDim == 3)
    {
        if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
                     std::is_same_v<ELayout, ctc::GNDHWK>)
        {
            return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
                          std::is_same_v<BLayout, ctc::GKZYXC> &&
                          std::is_same_v<ELayout, ctc::NDHWGK>)
        {
            return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                          std::is_same_v<BLayout, ctc::GKZYXC> &&
                          std::is_same_v<ELayout, ctc::NGKDHW>)
        {
            return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
        }
        else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
                          std::is_same_v<BLayout, ctc::GKCZYX> &&
                          std::is_same_v<ELayout, ctc::NGKDHW>)
        {
            return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
        }
    }
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
/// @details Classification is based on the A-tensor data type only; the B/E
/// tensor types are not inspected here.
template <typename Instance>
constexpr builder::DataType conv_data_type()
{
    using InstTraits = InstanceTraits<Instance>;
    using ADataType  = typename InstTraits::ADataType;
    if constexpr(std::is_same_v<ADataType, ck::half_t>)
    {
        return builder::DataType::FP16;
    }
    else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
    {
        return builder::DataType::BF16;
    }
    else if constexpr(std::is_same_v<ADataType, float>)
    {
        return builder::DataType::FP32;
    }
    else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
    {
        return builder::DataType::FP8;
    }
    else if constexpr(std::is_same_v<ADataType, int8_t>)
    {
        return builder::DataType::I8;
    }
    else if constexpr(std::is_same_v<ADataType, uint8_t>)
    {
        return builder::DataType::U8;
    }
    else
    {
        // NOTE(review): an unrecognized data type silently reports FP32, which
        // can mislead trait consumers — consider a static_assert instead;
        // confirm the fallback is intentional.
        // Default fallback
        return builder::DataType::FP32;
    }
}
/// @brief Derives the elementwise operation from op type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
/// @details Matches the operation's reflected name (via `detail::elementwise_op_name`)
/// case-insensitively against the known operation names. NOTE(review): there is no
/// terminal `else`, so an unrecognized name deduces a `void` return — confirm the
/// list covers every functor this is instantiated with.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
    if constexpr(detail::case_insensitive_equal(name, "Bias"))
    {
        return builder::ElementwiseOperation::BIAS;
    }
    else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
    {
        return builder::ElementwiseOperation::BIAS_CLAMP;
    }
    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
    {
        return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
    }
    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
    {
        return builder::ElementwiseOperation::BILINEAR;
    }
    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
    {
        return builder::ElementwiseOperation::CLAMP;
    }
    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
    {
        return builder::ElementwiseOperation::SCALE;
    }
    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
    {
        return builder::ElementwiseOperation::PASS_THROUGH;
    }
}
/// @brief Derives a gemm padding from a kernel instance type.
/// @tparam Instance - A Device Kernel object type.
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
/// @details One-to-one mapping from ck GemmSpecialization enumerators to the
/// builder's GemmPadding enumerators. NOTE(review): the local constant shadows
/// the function name `gemm_spec`; consider renaming for clarity. An unmatched
/// enumerator would deduce a `void` return (no terminal else).
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
{
    using InstTraits = InstanceTraits<Instance>;
    using enum builder::GemmPadding;
    using enum ck::tensor_operation::device::GemmSpecialization;
    constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
    if constexpr(gemm_spec == Default)
    {
        return DEFAULT;
    }
    else if constexpr(gemm_spec == MPadding)
    {
        return M_PADDING;
    }
    else if constexpr(gemm_spec == NPadding)
    {
        return N_PADDING;
    }
    else if constexpr(gemm_spec == KPadding)
    {
        return K_PADDING;
    }
    else if constexpr(gemm_spec == MNPadding)
    {
        return MN_PADDING;
    }
    else if constexpr(gemm_spec == MKPadding)
    {
        return MK_PADDING;
    }
    else if constexpr(gemm_spec == NKPadding)
    {
        return NK_PADDING;
    }
    else if constexpr(gemm_spec == MNKPadding)
    {
        return MNK_PADDING;
    }
    else if constexpr(gemm_spec == OPadding)
    {
        return O_PADDING;
    }
    else if constexpr(gemm_spec == MOPadding)
    {
        return MO_PADDING;
    }
    else if constexpr(gemm_spec == NOPadding)
    {
        return NO_PADDING;
    }
    else if constexpr(gemm_spec == KOPadding)
    {
        return KO_PADDING;
    }
    else if constexpr(gemm_spec == MNOPadding)
    {
        return MNO_PADDING;
    }
    else if constexpr(gemm_spec == MKOPadding)
    {
        return MKO_PADDING;
    }
    else if constexpr(gemm_spec == NKOPadding)
    {
        return NKO_PADDING;
    }
    else if constexpr(gemm_spec == MNKOPadding)
    {
        return MNKO_PADDING;
    }
}
/// @brief Primary template for extracting convolution traits.
/// @details This struct is the main entry point for reflecting on a convolution
/// kernel's properties. It is specialized to handle different kinds of input types
/// (device kernel instances and builder types); the primary template is never defined.
template <typename T>
struct ConvTraits;
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <typename Instance>
    requires requires { typename InstanceTraits<Instance>; }
struct ConvTraits<Instance>
{
    using InstTraits = InstanceTraits<Instance>;

    // --- Signature Information ---
    /// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
    static constexpr int spatial_dim = InstTraits::kSpatialDim;
    /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
    static constexpr builder::ConvDirection direction = conv_direction<Instance>();
    /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
    static constexpr auto layout = conv_layout<Instance>();
    /// @brief The primary data type used in the computation (e.g., FP16, FP32).
    static constexpr builder::DataType data_type = conv_data_type<Instance>();
    /// @brief Elementwise operation applied to the input (A) tensor.
    static constexpr builder::ElementwiseOperation input_element_op =
        elementwise_op<typename InstTraits::AElementwiseOperation>();
    /// @brief Elementwise operation applied to the weight (B) tensor.
    static constexpr builder::ElementwiseOperation weight_element_op =
        elementwise_op<typename InstTraits::BElementwiseOperation>();
    /// @brief Elementwise operation applied to the output (C/D/E) tensors.
    static constexpr builder::ElementwiseOperation output_element_op =
        elementwise_op<typename InstTraits::CDEElementwiseOperation>();
    /// @brief The GEMM specialization used by the kernel - padding.
    static constexpr auto gemm_padding = gemm_spec<Instance>();
    /// @brief The convolution-specific specialization (e.g., Default, 1x1).
    static constexpr auto conv_specialization = conv_spec<Instance>();

    // --- Algorithm Information ---
    /// @brief The total number of threads in a thread block (workgroup).
    static constexpr int thread_block_size = InstTraits::kBlockSize;
    /// @brief The dimensions of the data tile processed by the thread block.
    static constexpr DataTileInfo tile_dims = {
        .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
    /// @brief Configuration for the A-matrix (input) tile transfer.
    /// @details k0 is derived as KPerBlock / AK1, mirroring the K = k0 * k1 split.
    static constexpr InputTileTransferInfo a_tile_transfer = {
        .tile_dimensions = {.k0     = InstTraits::kKPerBlock / InstTraits::kAK1,
                            .m_or_n = InstTraits::kMPerBlock,
                            .k1     = InstTraits::kAK1},
        .transfer_params = {.k1                    = InstTraits::kAK1,
                            .thread_cluster_dims   = InstTraits::kAThreadClusterLengths,
                            .thread_cluster_order  = InstTraits::kAThreadClusterArrangeOrder,
                            .src_access_order      = InstTraits::kABlockTransferSrcAccessOrder,
                            .src_vector_dim        = InstTraits::kABlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kABlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
    /// @brief Configuration for the B-matrix (weights) tile transfer.
    static constexpr InputTileTransferInfo b_tile_transfer = {
        .tile_dimensions = {.k0     = InstTraits::kKPerBlock / InstTraits::kBK1,
                            .m_or_n = InstTraits::kNPerBlock,
                            .k1     = InstTraits::kBK1},
        .transfer_params = {.k1                    = InstTraits::kBK1,
                            .thread_cluster_dims   = InstTraits::kBThreadClusterLengths,
                            .thread_cluster_order  = InstTraits::kBThreadClusterArrangeOrder,
                            .src_access_order      = InstTraits::kBBlockTransferSrcAccessOrder,
                            .src_vector_dim        = InstTraits::kBBlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kBBlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
    /// @brief Parameters for the warp-level GEMM computation.
    static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
                                                 .gemm_n = InstTraits::kNPerXDL,
                                                 .m_iter = InstTraits::kMXdlPerWave,
                                                 .n_iter = InstTraits::kNXdlPerWave};
    /// @brief Configuration for the C-matrix (output) tile transfer.
    static constexpr OutputTileTransferInfo c_tile_transfer = {
        .shuffle_params      = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
        .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                InstTraits::kCThreadClusterLengths[1],
                                InstTraits::kCThreadClusterLengths[2],
                                InstTraits::kCThreadClusterLengths[3]},
        .scalar_per_vector   = InstTraits::kCBlockTransferScalarPerVector};

    /// @brief Helper to safely get the pipeline version.
    /// @details This is only available for some convolutions (e.g., forward).
    /// If not present in `InstanceTraits`, it returns a default value (V1).
    /// The template parameter defaults to `InstTraits` so the `requires`
    /// check is dependent and evaluated per instantiation.
    template <typename T = InstTraits>
    static constexpr auto get_pipeline_version()
    {
        if constexpr(requires { T::kPipelineVersion; })
        {
            return convert_pipeline_version<T::kPipelineVersion>();
        }
        else
        {
            // Return a default or indicate not available.
            // NOTE(review): V1 as "not available" is indistinguishable from a
            // real V1 — confirm consumers can tolerate this.
            return builder::PipelineVersion::V1;
        }
    }
    /// @brief The block GEMM pipeline version used by the kernel.
    static constexpr auto pipeline_version = get_pipeline_version();

    /// @brief Helper to safely get the pipeline scheduler.
    /// @details This is only available for some convolutions. If not present
    /// in `InstanceTraits`, it returns a default value. Prefers the block-GEMM
    /// scheduler constant when present, then falls back to the loop scheduler.
    template <typename T = InstTraits>
    static constexpr auto get_pipeline_scheduler()
    {
        if constexpr(requires { T::kPipelineScheduler; })
        {
            return convert_pipeline_scheduler<T::kPipelineScheduler>();
        }
        else if constexpr(requires { T::kLoopScheduler; })
        {
            return convert_pipeline_scheduler<T::kLoopScheduler>();
        }
        else
        {
            // Return a default or indicate not available.
            return builder::PipelineScheduler::DEFAULT;
        }
    }
    /// @brief The pipeline scheduler used by the kernel.
    static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder's factory, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
/// Members are mirrored one-for-one so either specialization exposes the
/// same set of names.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
          builder::ConvAlgorithmDescriptor auto ALGORITHM,
          builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
    using Factory  = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
    using Instance = typename Factory::Instance;

    // Delegate to Instance-based ConvTraits.
    using InstanceConvTraits = ConvTraits<Instance>;

    // Forward all members from Instance-based traits.
    static constexpr int spatial_dim                   = InstanceConvTraits::spatial_dim;
    static constexpr builder::ConvDirection direction  = InstanceConvTraits::direction;
    static constexpr auto layout                       = InstanceConvTraits::layout;
    static constexpr builder::DataType data_type       = InstanceConvTraits::data_type;
    static constexpr builder::ElementwiseOperation input_element_op =
        InstanceConvTraits::input_element_op;
    static constexpr builder::ElementwiseOperation weight_element_op =
        InstanceConvTraits::weight_element_op;
    static constexpr builder::ElementwiseOperation output_element_op =
        InstanceConvTraits::output_element_op;
    static constexpr auto gemm_padding                 = InstanceConvTraits::gemm_padding;
    static constexpr auto conv_specialization          = InstanceConvTraits::conv_specialization;
    static constexpr int thread_block_size             = InstanceConvTraits::thread_block_size;
    static constexpr DataTileInfo tile_dims            = InstanceConvTraits::tile_dims;
    static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
    static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
    static constexpr WarpGemmParams warp_gemm          = InstanceConvTraits::warp_gemm;
    static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
    static constexpr auto pipeline_version             = InstanceConvTraits::pipeline_version;
    static constexpr auto pipeline_scheduler           = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -14,18 +14,9 @@
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
#include <concepts>
namespace ck_tile::reflect {

View File

@@ -15,6 +15,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -14,6 +14,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -9,9 +9,11 @@
#include <array>
#include <string>
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <climits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
@@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name()
}
}
/// @brief Makes a case insensitive comparison of two string views.
/// @param a First string view
/// @param b Second string view
/// @return Whether the two string views are equal, ignoring ASCII case.
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
{
    if(a.size() != b.size())
        return false;
    // ASCII-only lowering: std::tolower is neither constexpr nor
    // locale-independent, so it cannot be used here.
    constexpr auto to_lower = [](char c) constexpr {
        return (c >= 'A' && c <= 'Z') ? static_cast<char>(c + ('a' - 'A')) : c;
    };
    for(std::size_t i = 0; i < a.size(); ++i)
    {
        if(to_lower(a[i]) != to_lower(b[i]))
            return false;
    }
    return true;
}
} // namespace ck_tile::reflect::detail

View File

@@ -128,29 +128,14 @@ enum class ElementwiseOperation
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
// Enums for pipeline versions & schedulers
// Unified pipeline version enumerator, used for both block GEMM and gridwise
// GEMM pipeline descriptions (see BlockGemmDescriptor and
// GridwiseWmmaGemmDescriptor concepts).
enum class PipelineVersion
{
    V1,
    V2,
    V3,
    V4,
    V5
};
enum struct BlockGemmPipelineScheduler
{
INTRAWAVE,
INTERWAVE,
};
// Enums for the gridwise GEMM pipeline versions.
enum class GridwiseGemmPipelineVersion
{
V1,
V2,
V3, // Only used in stream-K implementation
V4,
V5,
WEIGHT_ONLY
};
@@ -186,9 +171,47 @@ enum class ConvFwdSpecialization
FILTER_3x3
};
enum class LoopScheduler
// Enums for the backward data convolution specialization.
enum class ConvBwdDataSpecialization
{
    DEFAULT,
    // Specialized path for a 1x1 filter with unit stride and zero padding.
    FILTER_1X1_STRIDE1_PAD0,
};
// Enums for the backward weight convolution specialization.
enum class ConvBwdWeightSpecialization
{
    DEFAULT,
    // Specialized path for a 1x1 filter with unit stride and zero padding.
    FILTER_1X1_STRIDE1_PAD0,
    // Specialized path for a 1x1 filter with zero padding.
    FILTER_1X1_PAD0,
    // Presumably a path for odd input-channel counts (C) — confirm against
    // the backward-weight device implementation.
    ODD_C,
};
// Enums for the Gemm padding.
// Each value names the GEMM dimensions (M, N, K and, where applicable, O)
// that are padded; DEFAULT applies no padding.
enum class GemmPadding
{
    DEFAULT,
    M_PADDING,
    N_PADDING,
    K_PADDING,
    MN_PADDING,
    MK_PADDING,
    NK_PADDING,
    MNK_PADDING,
    O_PADDING,
    MO_PADDING,
    NO_PADDING,
    KO_PADDING,
    MNO_PADDING,
    MKO_PADDING,
    NKO_PADDING,
    MNKO_PADDING,
};
// Unified pipeline scheduler enumerator; DEFAULT lets the kernel choose,
// INTRAWAVE/INTERWAVE select the block GEMM scheduling policy.
enum class PipelineScheduler
{
    DEFAULT,
    INTRAWAVE,
    INTERWAVE
};

View File

@@ -64,6 +64,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
add_ck_builder_test(test_conv_traits
conv/test_conv_traits.cpp)
# Function to add all test_ckb targets to a list
function(collect_test_ckb_targets result_var)
# Get all targets in current directory

View File

@@ -27,7 +27,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
PipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
@@ -47,7 +47,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
PipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <concepts>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for ConvTraits tests
// Empty fixture; exists so the related tests share the ConvTraitsTest group name.
class ConvTraitsTest : public ::testing::Test
{
};
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// Declares a fully specified V3 device instance and checks that every
// reflected property matches the template arguments used below.
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWC, // ALayout
            ck::tensor_layout::convolution::GKYXC, // BLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWK, // ELayout
            ck::half_t,                            // ADataType
            ck::half_t,                            // BDataType
            float,                                 // AccDataType
            ck::half_t,                            // CShuffleDataType
            ck::Tuple<>,                           // DsDataType
            ck::half_t,                            // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            256,                     // BlockSize
            128,                     // MPerBlock
            128,                     // NPerBlock
            16,                      // KPerBlock
            8,                       // AK1
            8,                       // BK1
            32,                      // MPerXDL
            32,                      // NPerXDL
            4,                       // MXdlPerWave
            4,                       // NXdlPerWave
            ck::Sequence<4, 64, 1>,  // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,   // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // ABlockTransferSrcAccessOrder
            2,                       // ABlockTransferSrcVectorDim
            8,                       // ABlockTransferSrcScalarPerVector
            8,                       // ABlockTransferDstScalarPerVector_AK1
            1,                       // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,  // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,   // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // BBlockTransferSrcAccessOrder
            2,                       // BBlockTransferSrcVectorDim
            8,                       // BBlockTransferSrcScalarPerVector
            8,                       // BBlockTransferDstScalarPerVector_BK1
            1,                       // BBlockLdsExtraN
            1,                       // CShuffleMXdlPerWavePerShuffle
            1,                       // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
            ck::PipelineVersion::v1,                   // BlkGemmPipelineVer
            ck::half_t,                                // AComputeDataType
            ck::half_t,                                // BComputeDataType
            false>;                                    // DirectLoad
    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);
    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
    // Verify A tile transfer info (k0 = KPerBlock / AK1 = 16 / 8 = 2)
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
    // Verify B tile transfer info
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
    // Verify warp GEMM params
    EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
    EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
    EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
    EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
    // Verify output tile transfer info
    EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
    EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
    EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
    EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
    // Verify pipeline configuration
    EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Covers the non-V3 (base) kernel variant; only signature, specialization and
// tile-dimension traits are asserted here.
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWC, // ALayout
            ck::tensor_layout::convolution::GKYXC, // BLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWK, // ELayout
            ck::half_t,                            // ADataType
            ck::half_t,                            // BDataType
            float,                                 // AccDataType
            ck::half_t,                            // CShuffleDataType
            ck::Tuple<>,                           // DsDataType
            ck::half_t,                            // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            1,                       // NumGemmKPrefetchStage
            256,                     // BlockSize
            128,                     // MPerBlock
            128,                     // NPerBlock
            16,                      // KPerBlock
            8,                       // AK1
            8,                       // BK1
            32,                      // MPerXDL
            32,                      // NPerXDL
            4,                       // MXdlPerWave
            4,                       // NXdlPerWave
            ck::Sequence<4, 64, 1>,  // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,   // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // ABlockTransferSrcAccessOrder
            2,                       // ABlockTransferSrcVectorDim
            8,                       // ABlockTransferSrcScalarPerVector
            8,                       // ABlockTransferDstScalarPerVector_AK1
            1,                       // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,  // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,   // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // BBlockTransferSrcAccessOrder
            2,                       // BBlockTransferSrcVectorDim
            8,                       // BBlockTransferSrcScalarPerVector
            8,                       // BBlockTransferDstScalarPerVector_BK1
            1,                       // BBlockLdsExtraN
            1,                       // CShuffleMXdlPerWavePerShuffle
            1,                       // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                 // AComputeDataType
            ck::half_t,                 // BComputeDataType
            ck::LoopScheduler::Default, // LoopSched
            1>;                         // NumGroupsToMerge
    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);
    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Covers the large-tensor kernel variant; only signature, specialization and
// tile-dimension traits are asserted here.
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWC, // ALayout
            ck::tensor_layout::convolution::GKYXC, // BLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWK, // ELayout
            ck::half_t,                            // ADataType
            ck::half_t,                            // BDataType
            float,                                 // AccDataType
            ck::half_t,                            // CShuffleDataType
            ck::Tuple<>,                           // DsDataType
            ck::half_t,                            // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            1,                       // NumGemmKPrefetchStage
            256,                     // BlockSize
            128,                     // MPerBlock
            128,                     // NPerBlock
            16,                      // KPerBlock
            8,                       // AK1
            8,                       // BK1
            32,                      // MPerXDL
            32,                      // NPerXDL
            4,                       // MXdlPerWave
            4,                       // NXdlPerWave
            ck::Sequence<4, 64, 1>,  // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,   // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // ABlockTransferSrcAccessOrder
            2,                       // ABlockTransferSrcVectorDim
            8,                       // ABlockTransferSrcScalarPerVector
            8,                       // ABlockTransferDstScalarPerVector_AK1
            1,                       // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>,  // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,   // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,   // BBlockTransferSrcAccessOrder
            2,                       // BBlockTransferSrcVectorDim
            8,                       // BBlockTransferSrcScalarPerVector
            8,                       // BBlockTransferDstScalarPerVector_BK1
            1,                       // BBlockLdsExtraN
            1,                       // CShuffleMXdlPerWavePerShuffle
            1,                       // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                  // AComputeDataType
            ck::half_t,                  // BComputeDataType
            ck::LoopScheduler::Default>; // LoopSched
    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);
    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
} // anonymous namespace

View File

@@ -49,14 +49,14 @@ struct GridwiseWmmaGemm
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
GridwiseGemmPipelineVersion pipeline_version;
PipelineVersion pipeline_version;
};
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
struct BlockGemm
{
BlockGemmPipelineVersion pipeline_version;
BlockGemmPipelineScheduler scheduler;
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
@@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
@@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);

View File

@@ -16,7 +16,7 @@ using namespace test;
// Common test implementation
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
PipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
{
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
.src_access_order_b = {1, 0, 2}};
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
if(FwdPipelineVersion == PipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
else if(FwdPipelineVersion == PipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
else if(FwdPipelineVersion == PipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
else if(FwdPipelineVersion == PipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 2,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = GridwiseGemmPipelineVersion::V1};
.pipeline_version = PipelineVersion::V1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;