mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
refactored helpers to support bwd conv
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag>
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> constexpr ConvTraits
|
||||
instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
|
||||
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
|
||||
.num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
/*
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = bwd_wei_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1,
|
||||
InstTraits::kK0PerBlock), .b_tile_transfer =
|
||||
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock), .warp_gemm =
|
||||
conv_traits_xdl_warp_gemm_params<InstTraits>(), .c_tile_transfer = { .shuffle_params =
|
||||
{.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, .n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle}, .thread_cluster_dims =
|
||||
{InstTraits::kCThreadClusterLengths[0], InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
} */
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.layout = fwd_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
|
||||
@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.layout = fwd_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
|
||||
@@ -22,60 +22,21 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kK1},
|
||||
.transfer_params = {.k1 = InstTraits::kK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kK1},
|
||||
.transfer_params = {.k1 = InstTraits::kK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerWmma,
|
||||
.gemm_n = InstTraits::kNPerWmma,
|
||||
.m_iter = InstTraits::kMRepeat,
|
||||
.n_iter = InstTraits::kNRepeat},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMRepeatPerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNRepeatPerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
|
||||
InstTraits::kCDEThreadClusterLengths[1],
|
||||
InstTraits::kCDEThreadClusterLengths[2],
|
||||
InstTraits::kCDEThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector},
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = fwd_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
|
||||
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
|
||||
.num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
|
||||
@@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = conv_layout<Instance>(),
|
||||
.layout = fwd_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<Instance>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
@@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
|
||||
@@ -31,10 +31,10 @@
|
||||
/// 1. **Concepts**: Type trait concepts for checking if instance types have required members
|
||||
/// - Layout concepts (HasFwdConvLayouts)
|
||||
/// - Specialization concepts (HasGemmSpec)
|
||||
/// - Data type concepts (HasDataTypes)
|
||||
/// - Data type concepts (HasFwdConvDataTypes)
|
||||
/// - Operation concepts (HasElementwiseOps)
|
||||
/// - Tile parameter concepts (HasTileParams)
|
||||
/// - Composite concepts (IsXdlFwdConv, HasConvTraits)
|
||||
/// - Composite concepts (IsFwdConv, HasConvTraits)
|
||||
///
|
||||
/// 2. **Enum Conversions**: Functions to convert CK enums to builder enums
|
||||
/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion)
|
||||
@@ -64,6 +64,14 @@ concept HasFwdConvLayouts = requires {
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// Backwards weight layout concept - checks for In, wei and out layouts
|
||||
template <typename T>
|
||||
concept HasBwdWeiLayouts = requires {
|
||||
typename T::InLayout;
|
||||
typename T::WeiLayout;
|
||||
typename T::OutLayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
@@ -74,7 +82,10 @@ concept HasGemmSpec = requires {
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
concept HasFwdConvDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
template <typename T>
|
||||
concept HasBwdDataTypes = requires { typename T::InDataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
@@ -98,14 +109,17 @@ concept HasTileParams = requires {
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
concept IsFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasFwdConvDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
template <typename T>
|
||||
concept IsBwdWeiConv = HasBwdWeiLayouts<T> && HasGemmSpec<T> && HasBwdDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
concept HasConvTraits = IsFwdConv<InstanceTraits<T>> || IsBwdWeiConv<InstanceTraits<T>>;
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 2: ENUM CONVERSIONS
|
||||
@@ -319,26 +333,17 @@ template <typename A, typename B, typename E, int SpatialDim>
|
||||
"Check the conv_layout() function for the list of supported layout combinations.";
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
/// index 0 -> Input layout
|
||||
/// index 1 -> Weight layout
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
template <typename A, typename B, typename E, int kSpatialDim>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
switch(kSpatialDim)
|
||||
{
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
@@ -382,12 +387,47 @@ constexpr auto conv_layout()
|
||||
|
||||
// If we reach here, the layout combination is not supported
|
||||
// Call consteval function to trigger a compile-time error with a clear message
|
||||
report_unsupported_layout_error<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
report_unsupported_layout_error<A, B, E, kSpatialDim>();
|
||||
|
||||
// This return is unreachable but needed to satisfy the compiler
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
/// index 0 -> Input layout
|
||||
/// index 1 -> Weight layout
|
||||
/// index 2 -> Output layout
|
||||
|
||||
template <typename Instance>
|
||||
constexpr auto fwd_conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
/// index 0 -> Input layout
|
||||
/// index 1 -> Weight layout
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto bwd_wei_conv_layout()
|
||||
requires HasBwdWeiLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
|
||||
using A = typename InstanceTraits<Instance>::InLayout;
|
||||
using B = typename InstanceTraits<Instance>::WeiLayout;
|
||||
using E = typename InstanceTraits<Instance>::OutLayout;
|
||||
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Data Types
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -410,7 +450,7 @@ template <typename ADataType>
|
||||
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
requires HasFwdConvDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
@@ -640,14 +680,18 @@ constexpr auto get_pipeline_scheduler()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SECTION 4: Helper functions for common structures
|
||||
// ============================================================================
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
|
||||
constexpr InputTileTransferInfo conv_traits_a_transfer_params(int _k1)
|
||||
{
|
||||
return InputTileTransferInfo{
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.k1 = _k1},
|
||||
.transfer_params = {.k1 = _k1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
@@ -659,13 +703,13 @@ constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params()
|
||||
constexpr InputTileTransferInfo conv_traits_b_transfer_params(int _k1)
|
||||
{
|
||||
return InputTileTransferInfo{
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.k1 = _k1},
|
||||
.transfer_params = {.k1 = _k1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
@@ -695,11 +739,11 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
|
||||
return OutputTileTransferInfo{
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
|
||||
InstTraits::kCDEThreadClusterLengths[1],
|
||||
InstTraits::kCDEThreadClusterLengths[2],
|
||||
InstTraits::kCDEThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
@@ -721,10 +765,9 @@ constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr DataTileInfo conv_traits_data_tile()
|
||||
constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock)
|
||||
{
|
||||
return DataTileInfo{
|
||||
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
|
||||
return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
@@ -8,5 +8,7 @@
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
// Wmma instances
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
@@ -60,6 +60,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
@@ -152,7 +156,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -204,6 +209,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
|
||||
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
@@ -224,7 +232,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
|
||||
@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
@@ -150,6 +155,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
ComputeTypeA_,
|
||||
ComputeTypeB_>>
|
||||
{
|
||||
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
@@ -204,6 +212,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ck_tile::builder::ConvDirection;
|
||||
@@ -27,7 +29,129 @@ class ConvTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(ConvTraitsTest, ConvFwdTraitsWmmaExtraction)
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle
|
||||
/*TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, //
|
||||
ConvBackwardWeightSpecialization 256, //
|
||||
BlockSize 128, // MPerBlock 128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector>
|
||||
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(std::get<ck_tile::builder::ConvFwdSpecialization>(traits.conv_specialization),
|
||||
ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
}*/
|
||||
|
||||
// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
||||
TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
|
||||
Reference in New Issue
Block a user