diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000..2538c48372 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> +std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> constexpr ConvTraits +instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp new file mode 100644 index 0000000000..8087ea1b36 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { +/* +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1, +InstTraits::kK0PerBlock), .b_tile_transfer = +conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), .warp_gemm = +conv_traits_xdl_warp_gemm_params(), .c_tile_transfer = { .shuffle_params = +{.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, .n_gemms_per_shuffle = +InstTraits::kCShuffleNXdlPerWavePerShuffle}, .thread_cluster_dims = +{InstTraits::kCThreadClusterLengths[0], InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} */ + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index c42d2abf80..4e3e5a5e9a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), + .layout = fwd_conv_layout(), .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), @@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits() .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, .tile_dims = conv_traits_data_tile(), - .a_tile_transfer = conv_traits_xdl_a_transfer_params(), - .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kAK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kBK1), .warp_gemm = conv_traits_xdl_warp_gemm_params(), .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), .pipeline_version = get_pipeline_version(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index b3d13fb337..32cf532ccf 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), + .layout = fwd_conv_layout(), .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), @@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits() .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, .tile_dims = conv_traits_data_tile(), - .a_tile_transfer = conv_traits_xdl_a_transfer_params(), - .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kAK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kBK1), .warp_gemm = conv_traits_xdl_warp_gemm_params(), .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), .pipeline_version = get_pipeline_version(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 846f3899d4..9c907ae875 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -22,60 +22,21 @@ constexpr ConvTraits instance_to_conv_traits() using InstTraits = InstanceTraits; return ConvTraits{ - .spatial_dim = InstTraits::kSpatialDim, - .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), - .input_element_op = elementwise_op(), - .weight_element_op = elementwise_op(), - .output_element_op = elementwise_op(), - .gemm_padding = gemm_spec(), - .conv_specialization = conv_spec(), - .thread_block_size = InstTraits::kBlockSize, - .tile_dims = {.m = InstTraits::kMPerBlock, - .n = InstTraits::kNPerBlock, - .k = InstTraits::kKPerBlock}, - .a_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kK1}, - .transfer_params = {.k1 = InstTraits::kK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, - .b_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kK1}, - .transfer_params = {.k1 = InstTraits::kK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, - .warp_gemm = {.gemm_m = InstTraits::kMPerWmma, - .gemm_n = InstTraits::kNPerWmma, - .m_iter = InstTraits::kMRepeat, - .n_iter = InstTraits::kNRepeat}, - .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = - InstTraits::kCShuffleMRepeatPerShuffle, - .n_gemms_per_shuffle = - InstTraits::kCShuffleNRepeatPerShuffle}, - .thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0], - InstTraits::kCDEThreadClusterLengths[1], - InstTraits::kCDEThreadClusterLengths[2], - InstTraits::kCDEThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}, + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), .num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, .pipeline_version = get_pipeline_version(), .pipeline_scheduler = get_pipeline_scheduler(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index cf417ad959..719d3cb6d5 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -24,7 +24,7 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), + .layout = fwd_conv_layout(), .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), @@ -33,8 +33,8 @@ constexpr ConvTraits instance_to_conv_traits() .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, .tile_dims = conv_traits_data_tile(), - .a_tile_transfer = conv_traits_xdl_a_transfer_params(), - .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kAK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kBK1), .warp_gemm = conv_traits_xdl_warp_gemm_params(), .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), .pipeline_version = get_pipeline_version(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 60d11bce6e..6644c2ddd8 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -31,10 +31,10 @@ /// 1. **Concepts**: Type trait concepts for checking if instance types have required members /// - Layout concepts (HasFwdConvLayouts) /// - Specialization concepts (HasGemmSpec) -/// - Data type concepts (HasDataTypes) +/// - Data type concepts (HasFwdConvDataTypes) /// - Operation concepts (HasElementwiseOps) /// - Tile parameter concepts (HasTileParams) -/// - Composite concepts (IsXdlFwdConv, HasConvTraits) +/// - Composite concepts (IsFwdConv, HasConvTraits) /// /// 2. **Enum Conversions**: Functions to convert CK enums to builder enums /// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion) @@ -64,6 +64,14 @@ concept HasFwdConvLayouts = requires { typename T::ELayout; }; +// Backwards weight layout concept - checks for In, wei and out layouts +template +concept HasBwdWeiLayouts = requires { + typename T::InLayout; + typename T::WeiLayout; + typename T::OutLayout; +}; + // GEMM specialization concept - checks for kGemmSpecialization member template concept HasGemmSpec = requires { @@ -74,7 +82,10 @@ concept HasGemmSpec = requires { // Data types concept - checks for ADataType member template -concept HasDataTypes = requires { typename T::ADataType; }; +concept HasFwdConvDataTypes = requires { typename T::ADataType; }; + +template +concept HasBwdDataTypes = requires { typename T::InDataType; }; // Elementwise operations concept - checks for A/B/CDE elementwise operation types template @@ -98,14 +109,17 @@ concept HasTileParams = requires { // Comprehensive concept that checks if an instance has all XDL forward convolution traits // This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions template -concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && +concept IsFwdConv = HasFwdConvLayouts && HasGemmSpec && HasFwdConvDataTypes && + HasElementwiseOps && HasTileParams; + +template +concept IsBwdWeiConv = HasBwdWeiLayouts && HasGemmSpec && HasBwdDataTypes && HasElementwiseOps && HasTileParams; // Primary concept for checking if a type can be described -// Currently only forward convolutions are supported, but this can be extended -// in the future to include backward data and backward weight convolutions + template -concept HasConvTraits = IsXdlFwdConv>; +concept HasConvTraits = IsFwdConv> || IsBwdWeiConv>; // ============================================================================ // SECTION 2: ENUM CONVERSIONS @@ -319,26 +333,17 @@ template "Check the conv_layout() function for the list of supported layout combinations."; } -/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array corresponding to the tensor layouts: -/// index 0 -> Input layout -/// index 1 -> Weight layout -/// index 2 -> Output layout -template +template constexpr auto conv_layout() - requires HasFwdConvLayouts> { + // Helper lambda to construct layout array auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - using A = typename InstanceTraits::ALayout; - using B = typename InstanceTraits::BLayout; - using E = typename InstanceTraits::ELayout; namespace ctl = ck::tensor_layout::convolution; using enum builder::TensorLayout; - switch(InstanceTraits::kSpatialDim) + switch(kSpatialDim) { case 1: if constexpr(layouts_are) @@ -382,12 +387,47 @@ constexpr auto conv_layout() // If we reach here, the layout combination is not supported // Call consteval function to trigger a compile-time error with a clear message - report_unsupported_layout_error::kSpatialDim>(); + report_unsupported_layout_error(); // This return is unreachable but needed to satisfy the compiler return layouts(GNHWC, GKYXC, GNHWK); } +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout + +template +constexpr auto fwd_conv_layout() + requires HasFwdConvLayouts> +{ + + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + return conv_layout::kSpatialDim>(); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout +template +constexpr auto bwd_wei_conv_layout() + requires HasBwdWeiLayouts> +{ + + using A = typename InstanceTraits::InLayout; + using B = typename InstanceTraits::WeiLayout; + using E = typename InstanceTraits::OutLayout; + return conv_layout::kSpatialDim>(); +} + // ---------------------------------------------------------------------------- // Data Types // ---------------------------------------------------------------------------- @@ -410,7 +450,7 @@ template /// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). template constexpr builder::DataType conv_data_type() - requires HasDataTypes> + requires HasFwdConvDataTypes> { using InstTraits = InstanceTraits; using ADataType = typename InstTraits::ADataType; @@ -640,14 +680,18 @@ constexpr auto get_pipeline_scheduler() } } +// ============================================================================ +// SECTION 4: Helper functions for common structures +// ============================================================================ + template -constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params() +constexpr InputTileTransferInfo conv_traits_a_transfer_params(int _k1) { return InputTileTransferInfo{ - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, + .k1 = _k1}, + .transfer_params = {.k1 = _k1, .thread_cluster_dims = InstTraits::kAThreadClusterLengths, .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, @@ -659,13 +703,13 @@ constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params() } template -constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params() +constexpr InputTileTransferInfo conv_traits_b_transfer_params(int _k1) { return InputTileTransferInfo{ - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, + .k1 = _k1}, + .transfer_params = {.k1 = _k1, .thread_cluster_dims = InstTraits::kBThreadClusterLengths, .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, @@ -695,11 +739,11 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() return OutputTileTransferInfo{ .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + .thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0], + InstTraits::kCDEThreadClusterLengths[1], + InstTraits::kCDEThreadClusterLengths[2], + InstTraits::kCDEThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; } template @@ -721,10 +765,9 @@ constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() } template -constexpr DataTileInfo conv_traits_data_tile() +constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock) { - return DataTileInfo{ - .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0}; } } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index d8b48fbc9c..90db3e89e6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -8,5 +8,7 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + // Wmma instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 2c893b9c1d..29b0d3f0d3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -60,6 +60,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +{ +}; template ::value; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -224,7 +232,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cf..516766e72f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +{ +}; + template > { + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag; static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; static constexpr ck::index_t kNDimSpatial = NDimSpatial; @@ -204,6 +212,8 @@ struct InstanceTraits::value; static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 499062a3a3..de8b7c2dcc 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -12,6 +12,8 @@ #include #include +#include + namespace { using ck_tile::builder::ConvDirection; @@ -27,7 +29,129 @@ class ConvTraitsTest : public ::testing::Test { }; -TEST_F(ConvTraitsTest, ConvFwdTraitsWmmaExtraction) +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleDXdlCshuffle +/*TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default, // +ConvBackwardWeightSpecialization 256, // +BlockSize 128, // MPerBlock 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(std::get(traits.conv_specialization), + ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +}*/ + +// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction) { using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<