mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Convolution traits. (#3152)
Added: 1. Convolution traits & unit tests 2. Update builder enumerators to have representation of Convolution Kernels properties. 3. Unified builder pipeline version & scheduler enumerators
This commit is contained in:
@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
|
||||
// Concept for parameter that describe block GEMM problem.
|
||||
template <typename T>
|
||||
concept BlockGemmDescriptor = requires(T t) {
|
||||
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise WMMA GEMM problem.
|
||||
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.n_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
// Concept to check if struct specifies block GEMM.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm()
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
|
||||
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
}
|
||||
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
|
||||
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineScheduler");
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
}
|
||||
|
||||
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
|
||||
if constexpr(BG.pipeline_version == PipelineVersion::V1)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
|
||||
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
|
||||
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
|
||||
{
|
||||
return ck::LoopScheduler::Default;
|
||||
}
|
||||
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
|
||||
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
return ck::LoopScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown LoopScheduler");
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,29 +460,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
|
||||
if constexpr(pipeline_version == PipelineVersion::V1)
|
||||
{
|
||||
return ck::PipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::PipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
static_assert(false, "V3 is used only for stream-K.");
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::PipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
|
||||
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
|
||||
{
|
||||
return ck::PipelineVersion::weight_only;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
if constexpr(version == PipelineVersion::V1)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V2)
|
||||
else if constexpr(version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
else if constexpr(version == PipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
else if constexpr(version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
else if constexpr(version == PipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,719 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/conv_factory.hpp>
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/types.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v3)
|
||||
return V3;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == v5)
|
||||
return V5;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == weight_only)
|
||||
return WEIGHT_ONLY;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
|
||||
/// across multiple wavefronts.
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Intrawave)
|
||||
return INTRAWAVE;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
|
||||
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
|
||||
/// performance in certain scenarios.
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Default)
|
||||
return DEFAULT;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
|
||||
/// @brief Data tile dimensions processed by a workgroup.
|
||||
/// @details This struct defines the M, N, and K dimensions of the data tile
|
||||
/// that a single workgroup (thread block) is responsible for processing in the
|
||||
/// underlying GEMM computation.
|
||||
struct DataTileInfo
|
||||
{
|
||||
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
|
||||
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
|
||||
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
|
||||
};
|
||||
|
||||
/// @brief Dimensions for an input data tile transfer.
|
||||
/// @details Defines the shape of the input tile (A or B matrix) as it is
|
||||
/// transferred from global memory to LDS. The tile is conceptually divided
|
||||
/// into k0 and k1 dimensions.
|
||||
struct InputTileTransferDimensions
|
||||
{
|
||||
int k0; ///< The outer dimension of K, where K = k0 * k1.
|
||||
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
|
||||
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
|
||||
///< memory.
|
||||
};
|
||||
|
||||
/// @brief Parameters governing the transfer of an input tile.
|
||||
/// @details This struct holds configuration details for how an input tile is
|
||||
/// loaded from global memory into LDS, including thread clustering, memory
|
||||
/// access patterns, and vectorization settings.
|
||||
struct InputTileTransferParams
|
||||
{
|
||||
int k1; ///< The inner K dimension size, often matching the vectorization width.
|
||||
std::array<int, 3>
|
||||
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
|
||||
///< many threads are arranged on each axis.
|
||||
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
|
||||
///< input tensor dimensions.
|
||||
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
|
||||
///< dimension to read first).
|
||||
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
|
||||
///< (the contiguous dimension).
|
||||
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
|
||||
///< elements accessed per thread per instruction.
|
||||
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
|
||||
///< dimension.
|
||||
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
|
||||
///< conflicts.
|
||||
};
|
||||
|
||||
/// @brief Complete information for an input tile transfer.
|
||||
/// @details Combines the dimensional information and transfer parameters for
|
||||
/// a full description of an input tile's journey from global memory to LDS.
|
||||
struct InputTileTransferInfo
|
||||
{
|
||||
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
|
||||
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
|
||||
};
|
||||
|
||||
/// @brief Parameters for the warp-level GEMM computation.
|
||||
/// @details Defines the configuration of the GEMM operation performed by each
|
||||
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
|
||||
struct WarpGemmParams
|
||||
{
|
||||
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
|
||||
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
|
||||
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
|
||||
///< wavefront (MXdlPerWave).
|
||||
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
|
||||
///< wavefront (NXdlPerWave).
|
||||
};
|
||||
|
||||
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
|
||||
/// @details Configures how many MFMA instruction results are processed per
|
||||
/// wave in each iteration of the CShuffle routine.
|
||||
struct WarpShuffleParams
|
||||
{
|
||||
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
|
||||
///< per shuffle iteration.
|
||||
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
|
||||
///< per shuffle iteration.
|
||||
};
|
||||
|
||||
/// @brief Information for the output tile transfer (CShuffle).
|
||||
/// @details Describes how the final computed tile (C matrix) is written out from
|
||||
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
|
||||
struct OutputTileTransferInfo
|
||||
{
|
||||
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
|
||||
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
|
||||
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
|
||||
///< data into the output tensor.
|
||||
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
|
||||
///< output tensor.
|
||||
};
|
||||
|
||||
// Helper metafunctions to derive signature information from Instance types
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::FORWARD;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
|
||||
/// `builder::ConvBwdWeightSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvForwardSpecialization == Default)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_3x3;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ALayout = typename InstTraits::ALayout;
|
||||
using BLayout = typename InstTraits::BLayout;
|
||||
using ELayout = typename InstTraits::ELayout;
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if constexpr(InstTraits::kSpatialDim == 1)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNWK>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNHWK>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NHWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 3)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNDHWK>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCZYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
{
|
||||
return builder::DataType::FP16;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
{
|
||||
return builder::DataType::BF16;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
{
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
{
|
||||
return builder::DataType::FP8;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
{
|
||||
return builder::DataType::I8;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
{
|
||||
return builder::DataType::U8;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default fallback
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
/// @tparam ElementwiseOp Elementwise operation functor type.
|
||||
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
if constexpr(detail::case_insensitive_equal(name, "Bias"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BILINEAR;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALE;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
{
|
||||
return builder::ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
/// @tparam Instance - A Device Kernel object type.
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == Default)
|
||||
{
|
||||
return DEFAULT;
|
||||
}
|
||||
else if constexpr(gemm_spec == MPadding)
|
||||
{
|
||||
return M_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NPadding)
|
||||
{
|
||||
return N_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KPadding)
|
||||
{
|
||||
return K_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNPadding)
|
||||
{
|
||||
return MN_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKPadding)
|
||||
{
|
||||
return MK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKPadding)
|
||||
{
|
||||
return NK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKPadding)
|
||||
{
|
||||
return MNK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == OPadding)
|
||||
{
|
||||
return O_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MOPadding)
|
||||
{
|
||||
return MO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NOPadding)
|
||||
{
|
||||
return NO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KOPadding)
|
||||
{
|
||||
return KO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNOPadding)
|
||||
{
|
||||
return MNO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKOPadding)
|
||||
{
|
||||
return MKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKOPadding)
|
||||
{
|
||||
return NKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKOPadding)
|
||||
{
|
||||
return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Primary template for extracting convolution traits.
|
||||
/// @details This struct is the main entry point for reflecting on a convolution
|
||||
/// kernel's properties. It is specialized to handle different kinds of input types.
|
||||
template <typename T>
|
||||
struct ConvTraits;
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
/// All members are `static constexpr`, so every trait is available at compile time.
template <typename Instance>
    requires requires { typename InstanceTraits<Instance>; }
struct ConvTraits<Instance>
{
    using InstTraits = InstanceTraits<Instance>;

    // --- Signature Information ---
    /// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
    static constexpr int spatial_dim = InstTraits::kSpatialDim;
    /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
    static constexpr builder::ConvDirection direction = conv_direction<Instance>();
    /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
    static constexpr auto layout = conv_layout<Instance>();
    /// @brief The primary data type used in the computation (e.g., FP16, FP32).
    static constexpr builder::DataType data_type = conv_data_type<Instance>();

    /// @brief Elementwise operation applied to the input (A) tensor.
    static constexpr builder::ElementwiseOperation input_element_op =
        elementwise_op<typename InstTraits::AElementwiseOperation>();
    /// @brief Elementwise operation applied to the weight (B) tensor.
    static constexpr builder::ElementwiseOperation weight_element_op =
        elementwise_op<typename InstTraits::BElementwiseOperation>();
    /// @brief Elementwise operation applied to the output (C/D/E) tensors.
    static constexpr builder::ElementwiseOperation output_element_op =
        elementwise_op<typename InstTraits::CDEElementwiseOperation>();

    /// @brief The GEMM specialization used by the kernel - padding
    static constexpr auto gemm_padding = gemm_spec<Instance>();
    /// @brief The convolution-specific specialization (e.g., Default, 1x1).
    static constexpr auto conv_specialization = conv_spec<Instance>();

    // --- Algorithm Information ---
    /// @brief The total number of threads in a thread block (workgroup).
    static constexpr int thread_block_size = InstTraits::kBlockSize;
    /// @brief The dimensions of the data tile processed by the thread block.
    static constexpr DataTileInfo tile_dims = {
        .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};

    /// @brief Configuration for the A-matrix (input) tile transfer.
    /// @note k0 is derived as KPerBlock / AK1, i.e. the K dimension is split
    /// into k0 x k1 for vectorized transfer.
    static constexpr InputTileTransferInfo a_tile_transfer = {
        .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
                            .m_or_n = InstTraits::kMPerBlock,
                            .k1 = InstTraits::kAK1},
        .transfer_params = {.k1 = InstTraits::kAK1,
                            .thread_cluster_dims = InstTraits::kAThreadClusterLengths,
                            .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
                            .src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
                            .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kABlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};

    /// @brief Configuration for the B-matrix (weights) tile transfer.
    static constexpr InputTileTransferInfo b_tile_transfer = {
        .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
                            .m_or_n = InstTraits::kNPerBlock,
                            .k1 = InstTraits::kBK1},
        .transfer_params = {.k1 = InstTraits::kBK1,
                            .thread_cluster_dims = InstTraits::kBThreadClusterLengths,
                            .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
                            .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
                            .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
                            .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
                            .dst_scalar_per_vector_k1 =
                                InstTraits::kBBlockTransferDstScalarPerVectorK1,
                            .lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};

    /// @brief Parameters for the warp-level GEMM computation.
    static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
                                                 .gemm_n = InstTraits::kNPerXDL,
                                                 .m_iter = InstTraits::kMXdlPerWave,
                                                 .n_iter = InstTraits::kNXdlPerWave};

    /// @brief Configuration for the C-matrix (output) tile transfer.
    static constexpr OutputTileTransferInfo c_tile_transfer = {
        .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
                           .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
        .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                InstTraits::kCThreadClusterLengths[1],
                                InstTraits::kCThreadClusterLengths[2],
                                InstTraits::kCThreadClusterLengths[3]},
        .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};

    /// @brief Helper to safely get the pipeline version.
    /// @details This is only available for some convolutions (e.g., forward).
    /// If not present in `InstanceTraits`, it returns a default value.
    /// @note The template parameter defaulting to `InstTraits` keeps the
    /// `requires` check dependent, so the fallback branch is valid even for
    /// instances that lack `kPipelineVersion`.
    template <typename T = InstTraits>
    static constexpr auto get_pipeline_version()
    {
        if constexpr(requires { T::kPipelineVersion; })
        {
            return convert_pipeline_version<T::kPipelineVersion>();
        }
        else
        {
            // Return a default or indicate not available
            // NOTE(review): kernels that expose no pipeline-version trait are
            // reported as V1 — callers cannot distinguish "V1" from "unknown".
            // Confirm a dedicated "unspecified" enumerator is not needed.
            return builder::PipelineVersion::V1;
        }
    }

    /// @brief The block GEMM pipeline version used by the kernel.
    static constexpr auto pipeline_version = get_pipeline_version();

    /// @brief Helper to safely get the pipeline scheduler.
    /// @details This is only available for some convolutions. If not present
    /// in `InstanceTraits`, it returns a default value.
    /// Checks `kPipelineScheduler` first, then falls back to the legacy
    /// `kLoopScheduler` trait, and finally to `PipelineScheduler::DEFAULT`.
    template <typename T = InstTraits>
    static constexpr auto get_pipeline_scheduler()
    {
        if constexpr(requires { T::kPipelineScheduler; })
        {
            return convert_pipeline_scheduler<T::kPipelineScheduler>();
        }
        else if constexpr(requires { T::kLoopScheduler; })
        {
            return convert_pipeline_scheduler<T::kLoopScheduler>();
        }
        else
        {
            // Return a default or indicate not available
            return builder::PipelineScheduler::DEFAULT;
        }
    }

    /// @brief The pipeline scheduler used by the kernel.
    static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
|
||||
/// @details This specialization provides backward compatibility for reflecting
|
||||
/// on kernels defined via the `ConvBuilder` interface. It works by first
|
||||
/// creating the `Instance` via the builder's factory, and then delegating
|
||||
/// all trait extraction to the `ConvTraits<Instance>` specialization.
|
||||
template <builder::ConvSignatureDescriptor auto SIGNATURE,
|
||||
builder::ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
builder::StringLiteral VERSION>
|
||||
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
|
||||
{
|
||||
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
using Instance = typename Factory::Instance;
|
||||
|
||||
// Delegate to Instance-based ConvTraits
|
||||
using InstanceConvTraits = ConvTraits<Instance>;
|
||||
|
||||
// Forward all members from Instance-based traits
|
||||
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
|
||||
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
|
||||
static constexpr auto layout = InstanceConvTraits::layout;
|
||||
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
|
||||
|
||||
static constexpr builder::ElementwiseOperation input_element_op =
|
||||
InstanceConvTraits::input_element_op;
|
||||
static constexpr builder::ElementwiseOperation weight_element_op =
|
||||
InstanceConvTraits::weight_element_op;
|
||||
static constexpr builder::ElementwiseOperation output_element_op =
|
||||
InstanceConvTraits::output_element_op;
|
||||
|
||||
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
|
||||
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
|
||||
|
||||
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
|
||||
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
|
||||
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
|
||||
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
|
||||
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
|
||||
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
|
||||
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
|
||||
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -14,18 +14,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include "instance_traits_util.hpp"
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -9,9 +9,11 @@
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <climits>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
@@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name()
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Compares two string views for equality, ignoring ASCII case.
/// @param a First string view
/// @param b Second string view
/// @return true when the views have the same length and match
///         character-by-character after ASCII lower-casing.
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
{
    // Lower-cases a single ASCII letter; every other character passes through.
    // (std::tolower is not constexpr and is locale-dependent, hence manual.)
    constexpr auto ascii_lower = [](char c) constexpr -> char {
        return (c >= 'A' && c <= 'Z') ? static_cast<char>(c + ('a' - 'A')) : c;
    };

    if(a.size() != b.size())
    {
        return false;
    }

    for(std::size_t idx = 0; idx < a.size(); ++idx)
    {
        if(ascii_lower(a[idx]) != ascii_lower(b[idx]))
        {
            return false;
        }
    }
    return true;
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
@@ -128,29 +128,14 @@ enum class ElementwiseOperation
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
// Enums for pipeline versions & schedulers
|
||||
enum class PipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
enum struct BlockGemmPipelineScheduler
|
||||
{
|
||||
INTRAWAVE,
|
||||
INTERWAVE,
|
||||
};
|
||||
|
||||
// Enums for the gridwise GEMM pipeline versions.
|
||||
enum class GridwiseGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3, // Only used in stream-K implementation
|
||||
V4,
|
||||
V5,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -186,9 +171,47 @@ enum class ConvFwdSpecialization
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
enum class LoopScheduler
|
||||
// Enums for the backward data convolution specialization.
enum class ConvBwdDataSpecialization
{
    DEFAULT,                 // general case: no restriction on filter/stride/padding
    FILTER_1X1_STRIDE1_PAD0, // specialized path: 1x1 filter, unit stride, zero padding
};
|
||||
|
||||
// Enums for the backward weight convolution specialization.
enum class ConvBwdWeightSpecialization
{
    DEFAULT,                 // general case: no restriction on filter/stride/padding
    FILTER_1X1_STRIDE1_PAD0, // 1x1 filter, unit stride, zero padding
    FILTER_1X1_PAD0,         // 1x1 filter, zero padding (any stride)
    ODD_C,                   // NOTE(review): presumably handles an odd input-channel
                             // count — confirm against the device-op specialization.
};
|
||||
|
||||
// Enums for the Gemm padding.
// Each enumerator names the combination of GEMM dimensions that is padded
// (any subset of M, N, K and O).
// NOTE(review): "O" is a fourth padded dimension beyond the classic M/N/K
// triple — confirm its exact meaning against
// ck::tensor_operation::device::GemmSpecialization before relying on it.
enum class GemmPadding
{
    DEFAULT, // no padding applied
    M_PADDING,
    N_PADDING,
    K_PADDING,
    MN_PADDING,
    MK_PADDING,
    NK_PADDING,
    MNK_PADDING,
    O_PADDING,
    MO_PADDING,
    NO_PADDING,
    KO_PADDING,
    MNO_PADDING,
    MKO_PADDING,
    NKO_PADDING,
    MNKO_PADDING,
};
|
||||
|
||||
// Unified scheduler enumeration for block/gridwise GEMM pipelines.
// DEFAULT is also the value reported when a kernel instance exposes no
// scheduler trait (see ConvTraits::get_pipeline_scheduler).
enum class PipelineScheduler
{
    DEFAULT,
    INTRAWAVE,
    INTERWAVE
};
|
||||
|
||||
|
||||
@@ -64,6 +64,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
function(collect_test_ckb_targets result_var)
|
||||
# Get all targets in current directory
|
||||
|
||||
@@ -27,7 +27,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V2,
|
||||
PipelineVersion::V2,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
PipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
316
experimental/builder/test/conv/test_conv_traits.cpp
Normal file
316
experimental/builder/test/conv/test_conv_traits.cpp
Normal file
@@ -0,0 +1,316 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <concepts>
|
||||
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// Test fixture for ConvTraits tests.
// Intentionally empty: it exists only so the ConvTraits tests share one
// TEST_F fixture name; all checks are on compile-time constants.
class ConvTraitsTest : public ::testing::Test
{
};
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
// Builds one fully-specified V3 device instance and checks that every trait
// exposed by ConvTraits mirrors the corresponding template argument.
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
            2,                                               // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,           // ALayout
            ck::tensor_layout::convolution::GKYXC,           // BLayout
            ck::Tuple<>,                                     // DsLayout
            ck::tensor_layout::convolution::GNHWK,           // ELayout
            ck::half_t,                                      // ADataType
            ck::half_t,                                      // BDataType
            float,                                           // AccDataType
            ck::half_t,                                      // CShuffleDataType
            ck::Tuple<>,                                     // DsDataType
            ck::half_t,                                      // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            256,                                                       // BlockSize
            128,                                                       // MPerBlock
            128,                                                       // NPerBlock
            16,                                                        // KPerBlock
            8,                                                         // AK1
            8,                                                         // BK1
            32,                                                        // MPerXDL
            32,                                                        // NPerXDL
            4,                                                         // MXdlPerWave
            4,                                                         // NXdlPerWave
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_AK1
            1,                      // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_BK1
            1,                      // BBlockLdsExtraN
            1,                      // CShuffleMXdlPerWavePerShuffle
            1,                      // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
            ck::PipelineVersion::v1,                   // BlkGemmPipelineVer
            ck::half_t,                                // AComputeDataType
            ck::half_t,                                // BComputeDataType
            false>;                                    // DirectLoad

    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;

    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);

    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);

    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);

    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);

    // Verify A tile transfer info (k0 is derived: KPerBlock / AK1 = 16 / 8 = 2)
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);

    // Verify B tile transfer info (mirrors A: same cluster/vector configuration)
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);

    // Verify warp GEMM params
    EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
    EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
    EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
    EXPECT_EQ(Traits::warp_gemm.n_iter, 4);

    // Verify output tile transfer info
    EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
    EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
    EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
    EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);

    // Verify pipeline configuration
    EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle.
// Same signature as the V3 test above, but for the base (non-V3) device op,
// which exposes fewer traits — only signature/specialization/tile checks here.
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
            2,                                               // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,           // ALayout
            ck::tensor_layout::convolution::GKYXC,           // BLayout
            ck::Tuple<>,                                     // DsLayout
            ck::tensor_layout::convolution::GNHWK,           // ELayout
            ck::half_t,                                      // ADataType
            ck::half_t,                                      // BDataType
            float,                                           // AccDataType
            ck::half_t,                                      // CShuffleDataType
            ck::Tuple<>,                                     // DsDataType
            ck::half_t,                                      // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            1,                                                         // NumGemmKPrefetchStage
            256,                                                       // BlockSize
            128,                                                       // MPerBlock
            128,                                                       // NPerBlock
            16,                                                        // KPerBlock
            8,                                                         // AK1
            8,                                                         // BK1
            32,                                                        // MPerXDL
            32,                                                        // NPerXDL
            4,                                                         // MXdlPerWave
            4,                                                         // NXdlPerWave
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_AK1
            1,                      // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_BK1
            1,                      // BBlockLdsExtraN
            1,                      // CShuffleMXdlPerWavePerShuffle
            1,                      // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                 // AComputeDataType
            ck::half_t,                 // BComputeDataType
            ck::LoopScheduler::Default, // LoopSched
            1>;                         // NumGroupsToMerge

    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;

    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);

    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);

    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);

    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.
// Same signature and tile configuration as the base test, exercising the
// large-tensor device op (no NumGroupsToMerge parameter).
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
            2,                                               // NDimSpatial
            ck::tensor_layout::convolution::GNHWC,           // ALayout
            ck::tensor_layout::convolution::GKYXC,           // BLayout
            ck::Tuple<>,                                     // DsLayout
            ck::tensor_layout::convolution::GNHWK,           // ELayout
            ck::half_t,                                      // ADataType
            ck::half_t,                                      // BDataType
            float,                                           // AccDataType
            ck::half_t,                                      // CShuffleDataType
            ck::Tuple<>,                                     // DsDataType
            ck::half_t,                                      // EDataType
            ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
            ck::tensor_operation::device::ConvolutionForwardSpecialization::
                Default,                                               // ConvForwardSpecialization
            ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
            1,                                                         // NumGemmKPrefetchStage
            256,                                                       // BlockSize
            128,                                                       // MPerBlock
            128,                                                       // NPerBlock
            16,                                                        // KPerBlock
            8,                                                         // AK1
            8,                                                         // BK1
            32,                                                        // MPerXDL
            32,                                                        // NPerXDL
            4,                                                         // MXdlPerWave
            4,                                                         // NXdlPerWave
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_AK1
            1,                      // ABlockLdsExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_BK1
            1,                      // BBlockLdsExtraN
            1,                      // CShuffleMXdlPerWavePerShuffle
            1,                      // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CDEBlockTransferClusterLengths
            8,               // CDEBlockTransferScalarPerVector_NPerBlock
            ck::half_t,                  // AComputeDataType
            ck::half_t,                  // BComputeDataType
            ck::LoopScheduler::Default>; // LoopSched

    // Use ConvTraits to extract compile-time information
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;

    // Verify signature information
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);

    // Verify specializations
    EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);

    // Verify algorithm information
    EXPECT_EQ(Traits::thread_block_size, 256);

    // Verify tile dimensions
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
|
||||
} // anonymous namespace
|
||||
@@ -49,14 +49,14 @@ struct GridwiseWmmaGemm
|
||||
size_t n_per_wmma = 0;
|
||||
size_t m_wmma_per_wave = 0;
|
||||
size_t n_wmma_per_wave = 0;
|
||||
GridwiseGemmPipelineVersion pipeline_version;
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
|
||||
struct BlockGemm
|
||||
{
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
BlockGemmPipelineScheduler scheduler;
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
|
||||
@@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
LoopScheduler loop_scheduler;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
@@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
LoopScheduler loop_scheduler;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
@@ -16,7 +16,7 @@ using namespace test;
|
||||
// Common test implementation
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
PipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
{
|
||||
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
if(FwdPipelineVersion == PipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 2,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = GridwiseGemmPipelineVersion::V1};
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user