refactored helpers to support bwd conv

This commit is contained in:
Kevin Abraham
2026-01-15 10:01:05 +00:00
parent 9b1c8ae951
commit 6abcb1d5cf
11 changed files with 350 additions and 103 deletions

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag>
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> constexpr ConvTraits
instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/*
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1,
InstTraits::kK0PerBlock), .b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), .warp_gemm =
conv_traits_xdl_warp_gemm_params<InstTraits>(), .c_tile_transfer = { .shuffle_params =
{.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, .n_gemms_per_shuffle =
InstTraits::kCShuffleNXdlPerWavePerShuffle}, .thread_cluster_dims =
{InstTraits::kCThreadClusterLengths[0], InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
} */
} // namespace ck_tile::reflect::conv

View File

@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),

View File

@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),

View File

@@ -22,60 +22,21 @@ constexpr ConvTraits instance_to_conv_traits()
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = {.m = InstTraits::kMPerBlock,
.n = InstTraits::kNPerBlock,
.k = InstTraits::kKPerBlock},
.a_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kK1},
.transfer_params = {.k1 = InstTraits::kK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
.b_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kK1},
.transfer_params = {.k1 = InstTraits::kK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
.warp_gemm = {.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector},
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),

View File

@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),

View File

@@ -31,10 +31,10 @@
/// 1. **Concepts**: Type trait concepts for checking if instance types have required members
/// - Layout concepts (HasFwdConvLayouts)
/// - Specialization concepts (HasGemmSpec)
/// - Data type concepts (HasDataTypes)
/// - Data type concepts (HasFwdConvDataTypes)
/// - Operation concepts (HasElementwiseOps)
/// - Tile parameter concepts (HasTileParams)
/// - Composite concepts (IsXdlFwdConv, HasConvTraits)
/// - Composite concepts (IsFwdConv, HasConvTraits)
///
/// 2. **Enum Conversions**: Functions to convert CK enums to builder enums
/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion)
@@ -64,6 +64,14 @@ concept HasFwdConvLayouts = requires {
typename T::ELayout;
};
// Backwards weight layout concept - checks for In, wei and out layouts
template <typename T>
concept HasBwdWeiLayouts = requires {
typename T::InLayout;
typename T::WeiLayout;
typename T::OutLayout;
};
// GEMM specialization concept - checks for kGemmSpecialization member
template <typename T>
concept HasGemmSpec = requires {
@@ -74,7 +82,10 @@ concept HasGemmSpec = requires {
// Data types concept - checks for ADataType member
template <typename T>
concept HasDataTypes = requires { typename T::ADataType; };
concept HasFwdConvDataTypes = requires { typename T::ADataType; };
template <typename T>
concept HasBwdDataTypes = requires { typename T::InDataType; };
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
template <typename T>
@@ -98,14 +109,17 @@ concept HasTileParams = requires {
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
template <typename T>
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
concept IsFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasFwdConvDataTypes<T> &&
HasElementwiseOps<T> && HasTileParams<T>;
template <typename T>
concept IsBwdWeiConv = HasBwdWeiLayouts<T> && HasGemmSpec<T> && HasBwdDataTypes<T> &&
HasElementwiseOps<T> && HasTileParams<T>;
// Primary concept for checking if a type can be described
// Currently only forward convolutions are supported, but this can be extended
// in the future to include backward data and backward weight convolutions
template <typename T>
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
concept HasConvTraits = IsFwdConv<InstanceTraits<T>> || IsBwdWeiConv<InstanceTraits<T>>;
// ============================================================================
// SECTION 2: ENUM CONVERSIONS
@@ -319,26 +333,17 @@ template <typename A, typename B, typename E, int SpatialDim>
"Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
template <typename A, typename B, typename E, int kSpatialDim>
constexpr auto conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
namespace ctl = ck::tensor_layout::convolution;
using enum builder::TensorLayout;
switch(InstanceTraits<Instance>::kSpatialDim)
switch(kSpatialDim)
{
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
@@ -382,12 +387,47 @@ constexpr auto conv_layout()
// If we reach here, the layout combination is not supported
// Call consteval function to trigger a compile-time error with a clear message
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
report_unsupported_layout_error<A, B, E, kSpatialDim>();
// This return is unreachable but needed to satisfy the compiler
return layouts(GNHWC, GKYXC, GNHWK);
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto fwd_conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto bwd_wei_conv_layout()
requires HasBwdWeiLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::InLayout;
using B = typename InstanceTraits<Instance>::WeiLayout;
using E = typename InstanceTraits<Instance>::OutLayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
// ----------------------------------------------------------------------------
// Data Types
// ----------------------------------------------------------------------------
@@ -410,7 +450,7 @@ template <typename ADataType>
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
template <typename Instance>
constexpr builder::DataType conv_data_type()
requires HasDataTypes<InstanceTraits<Instance>>
requires HasFwdConvDataTypes<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
@@ -640,14 +680,18 @@ constexpr auto get_pipeline_scheduler()
}
}
// ============================================================================
// SECTION 4: Helper functions for common structures
// ============================================================================
template <typename InstTraits>
constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
constexpr InputTileTransferInfo conv_traits_a_transfer_params(int _k1)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
@@ -659,13 +703,13 @@ constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
}
template <typename InstTraits>
constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params()
constexpr InputTileTransferInfo conv_traits_b_transfer_params(int _k1)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
@@ -695,11 +739,11 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
}
template <typename InstTraits>
@@ -721,10 +765,9 @@ constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
}
template <typename InstTraits>
constexpr DataTileInfo conv_traits_data_tile()
constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock)
{
return DataTileInfo{
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0};
}
} // namespace ck_tile::reflect::conv

View File

@@ -8,5 +8,7 @@
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
// Wmma instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"

View File

@@ -60,6 +60,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
@@ -152,7 +156,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -204,6 +209,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
@@ -224,7 +232,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout

View File

@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -150,6 +155,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
ComputeTypeA_,
ComputeTypeB_>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
@@ -204,6 +212,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;

View File

@@ -12,6 +12,8 @@
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
namespace {
using ck_tile::builder::ConvDirection;
@@ -27,7 +29,129 @@ class ConvTraitsTest : public ::testing::Test
{
};
TEST_F(ConvTraitsTest, ConvFwdTraitsWmmaExtraction)
// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle
/*TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // InLayout
ck::tensor_layout::convolution::GKYXC, // WeiLayout
ck::tensor_layout::convolution::GNHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, //
ConvBackwardWeightSpecialization 256, //
BlockSize 128, // MPerBlock 128, // NPerBlock
16, // K0PerBlock
8, // K1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
1, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
1, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
8, // CDEBlockTransferScalarPerVector_NPerBlock_
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
1, // MaxTransposeTransferSrcScalarPerVector
1>; // MaxTransposeTransferDstScalarPerVector>
// Use ConvTraitsTmpl to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(std::get<ck_tile::builder::ConvFwdSpecialization>(traits.conv_specialization),
ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
// Verify A tile transfer info
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}*/
// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<