mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)
[CK_Builder] added bwd data kernels to builder factory (#4582) This PR adds bwd data wmma and xdl kernels to the ck builder, their instance and conv traits as well as tests for the above.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c8a8449eec
commit
5e06874aae
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_DATA,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWK,GKYXC,EmptyTuple,GNHWC",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_DATA,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_gemm_pad_params(0, 0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_V3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWK,GKYXC,EmptyTuple,GNHWC",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_DATA,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(cku::BwdDataGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_gemm_pad_params(0, 0)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdData_2DFp16_MultiD_Xdl_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWK,GKYXC,EmptyTuple,GNHWC",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16"}); // check compute types
|
||||
}
|
||||
@@ -19,6 +19,9 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -35,7 +38,390 @@ class ConvTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::half_t, // OutDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
float, // OutComputeType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Default, // ConvBackwardDataSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerWMMA
|
||||
32, // NPerWMMA
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerWavePerShuffle
|
||||
1, // CShuffleNRepeatPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
2, // NumGemmKPrefetchStage
|
||||
ck::LoopScheduler::Default, // BlkGemmPipeSched
|
||||
ck::PipelineVersion::v1>; // PipelineVerison
|
||||
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(traits.data_type, DataType::FP32);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 2);
|
||||
|
||||
// Verify pipeline configuration
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3
|
||||
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaV3TraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::half_t, // OutDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
float, // OutComputeType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Default, // ConvBackwardDataSpecialization
|
||||
false, // DoPadGemmM
|
||||
false, // DoPadGemmN
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerWMMA
|
||||
32, // NPerWMMA
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerWavePerShuffle
|
||||
1, // CShuffleNRepeatPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
ck::Sequence<8, 8, 8>, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(traits.data_type, DataType::FP32);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
EXPECT_FALSE(traits.do_pad_gemm_n.value());
|
||||
EXPECT_FALSE(traits.do_pad_gemm_m.value());
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
|
||||
// Verify pipeline configuration
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleXDLTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::half_t, // OutDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
float, // OutComputeType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Default, // ConvBackwardDataSpecialization
|
||||
false, // DoPadGemmM
|
||||
false, // DoPadGemmN
|
||||
1, // num_gemm_k_prefetch_stage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::LoopScheduler::Default, // BlkGemmPipeSched
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector
|
||||
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(traits.spatial_dim, 2);
|
||||
EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
|
||||
EXPECT_THAT(traits.layout,
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(traits.data_type, DataType::FP32);
|
||||
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(traits.thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(traits.tile_dims.m, 128);
|
||||
EXPECT_EQ(traits.tile_dims.n, 128);
|
||||
EXPECT_EQ(traits.tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
|
||||
EXPECT_FALSE(traits.do_pad_gemm_n.value());
|
||||
EXPECT_FALSE(traits.do_pad_gemm_m.value());
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
@@ -270,6 +656,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction)
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
|
||||
// Verify pipeline configuration
|
||||
}
|
||||
|
||||
@@ -516,6 +905,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3
|
||||
@@ -640,6 +1032,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
@@ -1001,6 +1396,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
||||
|
||||
Reference in New Issue
Block a user