mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
added reflection for grouped_conv_bwd weight_cshuffleV3
This commit is contained in:
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
///
/// Builds a ConvTraits value from the compile-time constants and type aliases
/// exposed by InstanceTraits<Instance>.
///
/// @tparam Instance Device instance type. It must satisfy HasInstanceTraits, and
///         InstanceTraits<Instance>::device_kernel_tag must be exactly
///         DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag; otherwise this
///         overload is not selected.
/// @return A ConvTraits aggregate filled in via designated initializers.
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
// NOTE(review): designated initializers must follow the member declaration
// order of ConvTraits (not visible here) -- confirm before reordering fields.
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
// Direction, layout and specialization helpers take the instance type itself,
// while the data-type / element-op helpers take individual InstTraits aliases.
.direction = conv_direction<Instance>(),
|
||||
.layout = bwd_wei_conv_layout<Instance>(),
|
||||
// Only InDataType is consulted for the reported data type.
.data_type = conv_data_type<typename InstTraits::InDataType>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
// kK0PerBlock is forwarded as the helper's only argument.
// NOTE(review): presumably the K extent of the block tile -- confirm against
// conv_traits_data_tile.
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
|
||||
// The A and B transfer helpers receive the same (kK1, kK0PerBlock) pair.
.a_tile_transfer =
|
||||
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
|
||||
.b_tile_transfer =
|
||||
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
// The C tile transfer is assembled inline from the CShuffle constants; the
// four thread-cluster lengths are read by index from kCThreadClusterLengths.
.c_tile_transfer =
|
||||
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
|
||||
|
||||
// Wmma instances
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
@@ -160,7 +160,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
@@ -175,7 +175,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using WeiElementwiseOperation = WeiElementwiseOperation_;
|
||||
using OutElementwiseOperation = OutElementwiseOperation_;
|
||||
|
||||
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
@@ -190,22 +190,40 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_K1;
|
||||
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
|
||||
|
||||
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
@@ -213,7 +231,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
@@ -232,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kNDimSpatial; // 1. NDimSpatial
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
|
||||
@@ -250,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_weight_spec_name(
|
||||
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
|
||||
oss << "," << kBlockSize; // 13. BlockSize
|
||||
oss << "," << kMPerBlock; // 14. MPerBlock
|
||||
oss << "," << kNPerBlock; // 15. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 16. K0PerBlock
|
||||
oss << "," << kK1; // 17. K1
|
||||
oss << "," << kMPerXDL; // 18. MPerXDL
|
||||
oss << "," << kNPerXDL; // 19. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 25.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
|
||||
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
|
||||
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 32.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
|
||||
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
|
||||
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
|
||||
oss << ","
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -30,7 +31,129 @@ class ConvTraitsTest : public ::testing::Test
|
||||
};
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDTraitsExtraction)
|
||||
TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleV3TraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t>; // BComputeDataType
|
||||
|
||||
// Use instance_to_conv_traits to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(traits.data_type, DataType::FP16);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(std::get<ck_tile::builder::ConvBwdWeightSpecialization>(traits.conv_specialization),
|
||||
ck_tile::builder::ConvBwdWeightSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
|
||||
Reference in New Issue
Block a user