From d20bdfd88be98f643444cef10b97f2485c01dddf Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Sun, 15 Feb 2026 10:43:28 +0000 Subject: [PATCH] added conv traits to bwd data wmma and wmma v3 instances --- .../builder/reflect/conv_description.hpp | 8 +- .../ck_tile/builder/reflect/conv_traits.hpp | 4 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 48 +++ ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 53 +++ ..._conv_bwd_data_multiple_d_xdl_cshuffle.hpp | 60 +++ ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 5 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 5 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 3 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 3 +- .../builder/reflect/conv_traits_helpers.hpp | 5 +- .../reflect/instance_to_conv_traits.hpp | 5 + ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 2 +- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- .../builder/test/conv/ck/test_conv_traits.cpp | 402 +++++++++++++++++- 14 files changed, 591 insertions(+), 14 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 5c09e4b735..875265d4ab 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -90,6 +90,10 @@ class ConvDescription : public Description 2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT)); else f.writeLine(2, "Struct does not contain optional gemm_padding argument"); + if(traits_.do_pad_gemm_m) + f.writeLine(2, "Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false)); + if(traits_.do_pad_gemm_n) + f.writeLine(2, "Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false)); f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization); // Pipeline section f.writeLine(2, "Pipeline version: ", traits_.pipeline_version); @@ -215,10 +219,10 @@ class ConvDescription : public Description f.writeLine(2, "Struct does not contain optional " "max_transpose_transfer_src_scalar_per_vector parameter"); - if(traits_.max_transpose_dst_scalar_per_vector) + if(traits_.max_transpose_transfer_dst_scalar_per_vector) f.writeLine(2, "Max Transpose dst scalar per vector: ", - traits_.max_transpose_dst_scalar_per_vector.value_or(0)); + traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0)); else f.writeLine( 2, diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 16a9c47f7e..21f6525534 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -108,8 +108,10 @@ struct ConvTraits builder::PipelineScheduler pipeline_scheduler; std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; - std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional max_transpose_transfer_dst_scalar_per_vector = std::nullopt; std::optional num_groups_to_merge = std::nullopt; + std::optional do_pad_gemm_m = std::nullopt; + std::optional do_pad_gemm_n = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..81ca13e2aa --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer( + InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..45757c0432 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_V3_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kAK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kBK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer( + InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, + .do_pad_gemm_m = InstTraits::kDoPadGemmM, + .do_pad_gemm_n = InstTraits::kDoPadGemmN, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..50fcc9b192 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,60 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kAK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kBK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, + .do_pad_gemm_m = InstTraits::kDoPadGemmM, + .do_pad_gemm_n = InstTraits::kDoPadGemmN, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 4f39b00b5c..0ce714adcc 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -42,8 +42,9 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, - .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 5666233091..ba663c12bb 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -49,8 +49,9 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, - .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 13625aa182..81a7bf76fd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -42,7 +42,8 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kMaxTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 39fde33217..d47b2ee4d3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -49,7 +49,8 @@ constexpr ConvTraits instance_to_conv_traits() .pipeline_scheduler = get_pipeline_scheduler(), .max_transpose_transfer_src_scalar_per_vector = InstTraits::kMaxTransposeTransferSrcScalarPerVector, - .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + .max_transpose_transfer_dst_scalar_per_vector = + InstTraits::kMaxTransposeTransferDstScalarPerVector, }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 4baf2423ee..3b6c006e6b 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -796,7 +796,8 @@ constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() } template -constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer( + ck::index_t CDEBlockTansferScalarPerVector = InstTraits::kCDEBlockTransferScalarPerVector) { return OutputTileTransferInfo{ .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, @@ -805,7 +806,7 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() InstTraits::kCDEThreadClusterLengths[1], InstTraits::kCDEThreadClusterLengths[2], InstTraits::kCDEThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; + .scalar_per_vector = CDEBlockTansferScalarPerVector}; } template diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index e10baaf712..cb4b3b2175 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -18,3 +18,8 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +// Bwd data instances +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 0346768f9c..4698f81366 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -243,7 +243,7 @@ struct InstanceTraits::value; // Static member function to generate instance string diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 4c70af06f3..40274f8c08 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -270,7 +270,7 @@ struct InstanceTraits< using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; using ComputeTypeA = ComputeTypeA_; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 7de7fae92d..0d8a9ca13a 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -19,6 +19,9 @@ #include #include #include +#include +#include +#include namespace { @@ -35,7 +38,392 @@ class ConvTraitsTest : public ::testing::Test { }; -// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWMMA + 32, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerWavePerShuffle + 1, // CShuffleNRepeatPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + 2, // NumGemmKPrefetchStage + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1>; // PipelineVerison + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 2); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + false, // DoPadGemmM + false, // DoPadGemmN + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerWMMA + 32, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerWavePerShuffle + 1, // CShuffleNRepeatPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + ck::Sequence<8, 8, 8>, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + EXPECT_FALSE(traits.do_pad_gemm_n.value()); + EXPECT_FALSE(traits.do_pad_gemm_m.value()); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleXDLTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::half_t, // OutDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + float, // OutComputeType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Default, // ConvBackwardDataSpecialization + false, // DoPadGemmM + false, // DoPadGemmN + 1, // num_gemm_k_prefetch_stage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP32); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + EXPECT_FALSE(traits.do_pad_gemm_n.value()); + EXPECT_FALSE(traits.do_pad_gemm_m.value()); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction) { // Define a concrete instance type with specific template parameters @@ -270,6 +658,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); + // Verify pipeline configuration } @@ -516,6 +907,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction) // Verify pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 @@ -640,6 +1034,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction) // Verify pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle @@ -1001,6 +1398,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction) // Verify pipeline configuration EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); + + EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1); + EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1); } // test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp