mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-22 16:17:37 +00:00
[CK_BUILDER] Add DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 to CK Builder (#5284) Add factory, InstanceTraits, and conv traits support for the WMMA V3 forward convolution kernel, enabling the CK Builder to generate and dispatch this kernel variant used by MIOpen on gfx11/gfx12 GPUs. ## Motivation As reported in issue #4944, MIOpen includes WMMA V3 forward convolution kernels, so this PR adds support for those kernels similarly to other supported kernels. ## Technical Details This follows the same implementation as the other kernels. I added some support for reflection, but I left a few TODOs since we need to generalize our convolution traits across WMMA/MFMA and CK/CKTile. ## Test Plan Added faster tests to `ninja smoke-builder` that check the instance-traits logic, and I added longer tests that instantiate kernels, following the existing pattern for other kernels. ## Test Result I tested all code with `ninja check-builder` on a gfx1101 build and ran on gfx1101. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
963 lines
61 KiB
C++
963 lines
61 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
#include "ck/utility/reduction_operator.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
|
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
|
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
|
|
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
|
|
|
namespace {
|
|
|
|
using ::testing::ElementsAre;
|
|
|
|
// NOTE: The V3ExtractsAllFieldsCorrectly test below performs detailed field-extraction testing
// for the V3 variant as a reference implementation. For new InstanceTraits specializations,
// only the instance_string() functionality needs to be tested. Each new specialization should have:
// 1. A test using instance_string<T>() directly (in this file).
// 2. A test using GetInstanceString() through a base-class pointer (in a separate
//    test_get_instance_string_*.cpp file).
// This avoids duplicating tests while still ensuring that both access methods work correctly.
|
|
TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly)
{
    // Reference test: build one fully specified XDL V3 forward-convolution instance and read
    // every field back through InstanceTraits, comparing it to the template argument used.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::Default,
        // BlockSize; MPerBlock, NPerBlock, KPerBlock; AK1, BK1
        256, 128, 128, 16, 8, 8,
        // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        32, 32, 4, 4,
        // A block transfer: cluster lengths (AK0_M_AK1), cluster arrange order, src access
        // order, src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // B block transfer: same field order as A (BK0_N_BK1 ... LdsExtraN)
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // CShuffle M/N XdlPerWavePerShuffle; CDE cluster lengths
        // (MBlock_MPerBlock_NBlock_NPerBlock) and scalar/vector (NPerBlock)
        1, 1, ck::Sequence<1, 32, 1, 8>, 8,
        // BlkGemmPipeSched, BlkGemmPipelineVer
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
        // AComputeDataType, BComputeDataType, DirectLoad, NumGroupsToMerge
        F16, F16, false, 1>;

    using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;

    // ---- Integral template parameters ---------------------------------------------------
    EXPECT_EQ(Traits::kSpatialDim, 2);

    EXPECT_EQ(Traits::kBlockSize, 256);
    EXPECT_EQ(Traits::kMPerBlock, 128);
    EXPECT_EQ(Traits::kNPerBlock, 128);
    EXPECT_EQ(Traits::kKPerBlock, 16);

    EXPECT_EQ(Traits::kAK1, 8);
    EXPECT_EQ(Traits::kBK1, 8);
    EXPECT_EQ(Traits::kMPerXDL, 32);
    EXPECT_EQ(Traits::kNPerXDL, 32);
    EXPECT_EQ(Traits::kMXdlPerWave, 4);
    EXPECT_EQ(Traits::kNXdlPerWave, 4);

    EXPECT_EQ(Traits::kABlockTransferSrcVectorDim, 2);
    EXPECT_EQ(Traits::kABlockTransferSrcScalarPerVector, 8);
    EXPECT_EQ(Traits::kABlockTransferDstScalarPerVectorK1, 8);
    EXPECT_EQ(Traits::kABlockLdsExtraM, 1);

    EXPECT_EQ(Traits::kBBlockTransferSrcVectorDim, 2);
    EXPECT_EQ(Traits::kBBlockTransferSrcScalarPerVector, 8);
    EXPECT_EQ(Traits::kBBlockTransferDstScalarPerVectorK1, 8);
    EXPECT_EQ(Traits::kBBlockLdsExtraN, 1);

    EXPECT_EQ(Traits::kCShuffleMXdlPerWavePerShuffle, 1);
    EXPECT_EQ(Traits::kCShuffleNXdlPerWavePerShuffle, 1);
    EXPECT_EQ(Traits::kCBlockTransferScalarPerVector, 8);

    // ---- Pipeline enums ------------------------------------------------------------------
    EXPECT_EQ(Traits::kPipelineScheduler, ck::BlockGemmPipelineScheduler::Intrawave);
    EXPECT_EQ(Traits::kPipelineVersion, ck::BlockGemmPipelineVersion::v1);

    // ---- Sequence parameters, compared element-wise with googlemock matchers -------------
    EXPECT_THAT(Traits::kAThreadClusterLengths, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::kAThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::kABlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));

    EXPECT_THAT(Traits::kBThreadClusterLengths, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::kBThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::kBBlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));

    EXPECT_THAT(Traits::kCThreadClusterLengths, ElementsAre(1, 32, 1, 8));

    // ---- Type parameters -----------------------------------------------------------------
    EXPECT_TRUE((std::is_same_v<Traits::ALayout, ck_conv::GNHWC>));
    EXPECT_TRUE((std::is_same_v<Traits::BLayout, ck_conv::GKYXC>));
    EXPECT_TRUE((std::is_same_v<Traits::DsLayout, ck::Tuple<>>));
    EXPECT_TRUE((std::is_same_v<Traits::ELayout, ck_conv::GNHWK>));

    EXPECT_TRUE((std::is_same_v<Traits::ADataType, F16>));
    EXPECT_TRUE((std::is_same_v<Traits::BDataType, F16>));
    EXPECT_TRUE((std::is_same_v<Traits::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Traits::CShuffleDataType, F16>));
    EXPECT_TRUE((std::is_same_v<Traits::DsDataType, ck::Tuple<>>));
    EXPECT_TRUE((std::is_same_v<Traits::EDataType, F16>));
    EXPECT_TRUE((std::is_same_v<Traits::AComputeDataType, F16>));
    EXPECT_TRUE((std::is_same_v<Traits::BComputeDataType, F16>));

    EXPECT_TRUE((std::is_same_v<Traits::AElementwiseOperation, PassThrough>));
    EXPECT_TRUE((std::is_same_v<Traits::BElementwiseOperation, PassThrough>));
    EXPECT_TRUE((std::is_same_v<Traits::CDEElementwiseOperation, PassThrough>));
}
|
|
|
|
TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
{
    // instance_string<T>() must render every template argument of the XDL V3 instance, in
    // declaration order, as a comma-separated list after the kernel name.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::Default,
        // BlockSize; MPerBlock, NPerBlock, KPerBlock; AK1, BK1
        256, 128, 128, 16, 8, 8,
        // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        32, 32, 4, 4,
        // A block transfer: cluster lengths, arrange order, src access order,
        // src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // B block transfer: same field order as A
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // CShuffle M/N XdlPerWavePerShuffle; CDE cluster lengths and scalar/vector
        1, 1, ck::Sequence<1, 32, 1, 8>, 8,
        // BlkGemmPipeSched, BlkGemmPipelineVer
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
        // AComputeDataType, BComputeDataType, DirectLoad, NumGroupsToMerge
        F16, F16, false, 1>;

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // One literal piece per parameter family; the pieces concatenate into the full string.
    const std::string expected_str =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
        "<2,GNHWC,GKYXC,EmptyTuple,GNHWK"               // NDimSpatial; A, B, Ds, E layouts
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16"          // A, B, Acc, CShuffle, Ds, E types
        ",PassThrough,PassThrough,PassThrough"          // A, B, CDE element-wise ops
        ",Default,Default"                              // ConvForwardSpecialization, GemmSpec
        ",256,128,128,16,8,8"                           // BlockSize; M/N/KPerBlock; AK1, BK1
        ",32,32,4,4"                                    // M/NPerXDL, M/NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // A block transfer
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // B block transfer
        ",1,1,Seq(1,32,1,8),8"                          // CShuffle + CDE block transfer
        ",Intrawave,v1"                                 // BlkGemmPipeSched, PipelineVer
        ",fp16,fp16"                                    // A/B ComputeDataType
        ",false"                                        // DirectLoad
        ",1>";                                          // NumGroupsToMerge

    EXPECT_EQ(actual, expected_str);
}
|
|
|
|
TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat)
{
    // instance_string<T>() for the non-V3 XDL ABD kernel: same shape as V3 but with
    // NumGemmKPrefetchStage and a LoopScheduler instead of the block-GEMM pipeline knobs.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::Default,
        // NumGemmKPrefetchStage
        1,
        // BlockSize; MPerBlock, NPerBlock, KPerBlock; AK1, BK1
        256, 128, 128, 16, 8, 8,
        // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        32, 32, 4, 4,
        // A block transfer: cluster lengths, arrange order, src access order,
        // src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // B block transfer: same field order as A
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // CShuffle M/N XdlPerWavePerShuffle; CDE cluster lengths and scalar/vector
        1, 1, ck::Sequence<1, 32, 1, 8>, 8,
        // AComputeDataType, BComputeDataType, LoopSched, NumGroupsToMerge
        F16, F16, ck::LoopScheduler::Default, 1>;

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    const std::string expected_str =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
        "<2,GNHWC,GKYXC,EmptyTuple,GNHWK"               // NDimSpatial; A, B, Ds, E layouts
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16"          // A, B, Acc, CShuffle, Ds, E types
        ",PassThrough,PassThrough,PassThrough"          // A, B, CDE element-wise ops
        ",Default,Default"                              // ConvForwardSpecialization, GemmSpec
        ",1"                                            // NumGemmKPrefetchStage
        ",256,128,128,16,8,8"                           // BlockSize; M/N/KPerBlock; AK1, BK1
        ",32,32,4,4"                                    // M/NPerXDL, M/NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // A block transfer
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // B block transfer
        ",1,1,Seq(1,32,1,8),8"                          // CShuffle + CDE block transfer
        ",fp16,fp16"                                    // A/B ComputeDataType
        ",Default"                                      // LoopSched
        ",1>";                                          // NumGroupsToMerge

    EXPECT_EQ(actual, expected_str);
}
|
|
|
|
TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat)
{
    // instance_string<T>() for the large-tensor XDL MultipleD kernel. All 48 template
    // parameters are passed explicitly; the last one is the LoopScheduler.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::Default,
        // NumGemmKPrefetchStage
        1,
        // BlockSize; MPerBlock, NPerBlock, KPerBlock; AK1, BK1
        256, 128, 128, 16, 8, 8,
        // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        32, 32, 4, 4,
        // A block transfer: cluster lengths, arrange order, src access order,
        // src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // B block transfer: same field order as A
        ck::Sequence<4, 64, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, 1,
        // CShuffle M/N XdlPerWavePerShuffle; CDE cluster lengths and scalar/vector
        1, 1, ck::Sequence<1, 32, 1, 8>, 8,
        // AComputeDataType, BComputeDataType, LoopSched
        F16, F16, ck::LoopScheduler::Default>;

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    const std::string expected_str =
        "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
        "<2,GNHWC,GKYXC,EmptyTuple,GNHWK"               // NDimSpatial; A, B, Ds, E layouts
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16"          // A, B, Acc, CShuffle, Ds, E types
        ",PassThrough,PassThrough,PassThrough"          // A, B, CDE element-wise ops
        ",Default,Default"                              // ConvForwardSpecialization, GemmSpec
        ",1"                                            // NumGemmKPrefetchStage
        ",256,128,128,16,8,8"                           // BlockSize; M/N/KPerBlock; AK1, BK1
        ",32,32,4,4"                                    // M/NPerXDL, M/NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // A block transfer
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true" // B block transfer
        ",1,1,Seq(1,32,1,8),8"                          // CShuffle + CDE block transfer
        ",fp16,fp16"                                    // A/B ComputeDataType
        ",Default>";                                    // LoopSched

    EXPECT_EQ(actual, expected_str);
}
|
|
|
|
TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
{
    // instance_string<T>() for the WMMA MultipleD kernel. All 46 template parameters are
    // passed explicitly; WMMA uses a single K1 and M/NRepeat instead of XDL wave counts.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::MNKPadding,
        // NumGemmKPrefetchStage; BlockSize; MPerBlock, NPerBlock, KPerBlock; K1
        1, 128, 64, 64, 32, 8,
        // MPerWmma, NPerWmma, MRepeat, NRepeat
        16, 16, 2, 2,
        // A block transfer: cluster lengths, arrange order, src access order,
        // src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 8, 1,
        // B block transfer: same field order as A
        ck::Sequence<4, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 8, 1,
        // CShuffle M/N RepeatPerShuffle; CDE shuffle cluster lengths and scalar/vector
        1, 1, ck::Sequence<1, 32, 1, 4>, 1,
        // LoopSched, PipelineVer
        ck::LoopScheduler::Default, ck::PipelineVersion::v1>;

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    const std::string expected_str =
        "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
        "<2,GNHWC,GKYXC,EmptyTuple,GNHWK"               // NDimSpatial; A, B, Ds, E layouts
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16"          // A, B, Acc, CShuffle, Ds, E types
        ",PassThrough,PassThrough,PassThrough"          // A, B, CDE element-wise ops
        ",Default,MNKPadding"                           // ConvForwardSpecialization, GemmSpec
        ",1,128,64,64,32,8"                             // KPrefetch; BlockSize; M/N/KPerBlock; K1
        ",16,16,2,2"                                    // M/NPerWmma, M/NRepeat
        ",Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true" // A block transfer
        ",Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true" // B block transfer
        ",1,1,Seq(1,32,1,4),1"                          // CShuffle + CDE shuffle transfer
        ",Default,v1>";                                 // LoopSched, PipelineVer

    EXPECT_EQ(actual, expected_str);
}
|
|
|
|
TEST(InstanceTraits, WmmaV3InstanceStringReturnsCorrectFormat)
{
    // instance_string<T>() for the WMMA V3 ABD kernel. The instance below stops after
    // BlkGemmPipelineVer, so the trailing fields in the expected string come from the
    // parameters that are not spelled out here.
    namespace ck_conv = ck::tensor_layout::convolution;
    namespace ck_dev  = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    using DeviceInstance = ck_dev::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
        // NDimSpatial; A, B, Ds, E layouts
        2, ck_conv::GNHWC, ck_conv::GKYXC, ck::Tuple<>, ck_conv::GNHWK,
        // A, B, Acc, CShuffle, Ds, E data types
        F16, F16, float, F16, ck::Tuple<>, F16,
        // A, B, CDE element-wise operations
        PassThrough, PassThrough, PassThrough,
        // ConvForwardSpecialization, GemmSpec
        ck_dev::ConvolutionForwardSpecialization::Default,
        ck_dev::GemmSpecialization::MNKPadding,
        // BlockSize; MPerBlock, NPerBlock, KPerBlock; AK1, BK1
        64, 64, 64, 32, 8, 8,
        // MPerWmma, NPerWmma, MRepeat, NRepeat
        16, 16, 4, 2,
        // A block transfer: cluster lengths, arrange order, src access order,
        // src vector dim, src scalar/vector, dst scalar/vector (AK1), LdsExtraM
        ck::Sequence<4, 16, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 8, 1,
        // B block transfer: same field order as A
        ck::Sequence<4, 16, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 1, 8, 1,
        // CShuffle M/N RepeatPerShuffle; CDE cluster lengths and scalar/vector
        1, 1, ck::Sequence<1, 16, 1, 4>, 1,
        // BlkGemmPipeSched, BlkGemmPipelineVer
        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    const std::string expected_str =
        "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
        "<2,GNHWC,GKYXC,EmptyTuple,GNHWK"               // NDimSpatial; A, B, Ds, E layouts
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16"          // A, B, Acc, CShuffle, Ds, E types
        ",PassThrough,PassThrough,PassThrough"          // A, B, CDE element-wise ops
        ",Default,MNKPadding"                           // ConvForwardSpecialization, GemmSpec
        ",64,64,64,32,8,8"                              // BlockSize; M/N/KPerBlock; AK1, BK1
        ",16,16,4,2"                                    // M/NPerWmma, M/NRepeat
        ",Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true" // A block transfer
        ",Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true" // B block transfer
        ",1,1,Seq(1,16,1,4),1"                          // CShuffle + CDE block transfer
        ",Intrawave,v1"                                 // BlkGemmPipeSched, PipelineVer
        ",true"                                         // UseThreadTileTransfer (not passed above)
        ",fp16,fp16"                                    // A/B ComputeDataType (not passed above)
        ",1>";                                          // NumGroupsToMerge (not passed above)

    EXPECT_EQ(actual, expected_str);
}
|
|
|
|
// Reflection check for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: instance_string<> must
// serialize all 42 template arguments, in declaration order, using the reflection spellings
// (fp16/fp32 for data types, bare layout names, EmptyTuple for ck::Tuple<>, Seq(...) for
// ck::Sequence<...>).
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
{
    namespace dev     = ck::tensor_operation::device;
    namespace conv    = ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Half        = ck::half_t;
    using Empty       = ck::Tuple<>;

    using DlConvFwd = dev::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
        /*NDimSpatial=*/2,
        /*ADataType=*/Half,
        /*BDataType=*/Half,
        /*DsDataType=*/Empty,
        /*EDataType=*/Half,
        /*AccDataType=*/float,
        /*ALayout=*/conv::GNHWC,
        /*BLayout=*/conv::GKYXC,
        /*DsLayout=*/Empty,
        /*ELayout=*/conv::GNHWK,
        /*AElementwiseOperation=*/PassThrough,
        /*BElementwiseOperation=*/PassThrough,
        /*CDEElementwiseOperation=*/PassThrough,
        /*ConvForwardSpecialization=*/dev::ConvolutionForwardSpecialization::Default,
        /*GemmSpec=*/dev::GemmSpecialization::MNKPadding,
        /*BlockSize=*/8,
        /*MPerBlock=*/16,
        /*NPerBlock=*/4,
        /*K0PerBlock=*/2,
        /*K1=*/1,
        /*M1PerThread=*/1,
        /*N1PerThread=*/2,
        /*KPerThread=*/1,
        /*M1N1ThreadClusterM1Xs=*/ck::Sequence<4, 2>,
        /*M1N1ThreadClusterN1Xs=*/ck::Sequence<1, 1>,
        /*ABlockTransferThreadSliceLengths_K0_M0_M1_K1=*/ck::Sequence<2, 1, 2, 1>,
        /*ABlockTransferThreadClusterLengths_K0_M0_M1_K1=*/ck::Sequence<1, 1, 8, 1>,
        /*ABlockTransferThreadClusterArrangeOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*ABlockTransferSrcAccessOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1=*/ck::Sequence<1, 1, 1, 1>,
        /*ABlockTransferSrcVectorTensorContiguousDimOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1=*/ck::Sequence<1, 1, 1, 1>,
        /*BBlockTransferThreadSliceLengths_K0_N0_N1_K1=*/ck::Sequence<1, 1, 1, 1>,
        /*BBlockTransferThreadClusterLengths_K0_N0_N1_K1=*/ck::Sequence<2, 1, 4, 1>,
        /*BBlockTransferThreadClusterArrangeOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*BBlockTransferSrcAccessOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1=*/ck::Sequence<1, 1, 1, 1>,
        /*BBlockTransferSrcVectorTensorContiguousDimOrder=*/ck::Sequence<1, 2, 0, 3>,
        /*BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1=*/ck::Sequence<1, 1, 1, 1>,
        /*CThreadTransferSrcDstAccessOrder=*/ck::Sequence<0, 1, 2, 3, 4, 5>,
        /*CThreadTransferSrcDstVectorDim=*/5,
        /*CThreadTransferDstScalarPerVector=*/1>;

    const std::string actual = ck_tile::reflect::instance_string<DlConvFwd>();

    // Adjacent literals are concatenated by the compiler; each line covers one logical group
    // of template arguments, in the same order as the alias above.
    const std::string expected =
        "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<"
        "2,"                                   // NDimSpatial
        "fp16,fp16,EmptyTuple,fp16,fp32,"      // A, B, Ds, E, Acc data types
        "GNHWC,GKYXC,EmptyTuple,GNHWK,"        // A, B, Ds, E layouts
        "PassThrough,PassThrough,PassThrough," // A, B, CDE elementwise ops
        "Default,MNKPadding,"                  // ConvForwardSpecialization, GemmSpec
        "8,16,4,2,1,"                          // BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1
        "1,2,1,"                               // M1PerThread, N1PerThread, KPerThread
        "Seq(4,2),Seq(1,1),"                   // M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
        // A block transfer: ThreadSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder,
        // SrcAccessOrder, SrcVectorTensorLengths, SrcVectorTensorContiguousDimOrder,
        // DstVectorTensorLengths.
        "Seq(2,1,2,1),Seq(1,1,8,1),Seq(1,2,0,3),Seq(1,2,0,3),"
        "Seq(1,1,1,1),Seq(1,2,0,3),Seq(1,1,1,1),"
        // B block transfer: the same seven fields, in the same order.
        "Seq(1,1,1,1),Seq(2,1,4,1),Seq(1,2,0,3),Seq(1,2,0,3),"
        "Seq(1,1,1,1),Seq(1,2,0,3),Seq(1,1,1,1),"
        // C thread transfer: SrcDstAccessOrder, SrcDstVectorDim, DstScalarPerVector.
        "Seq(0,1,2,3,4,5),5,1"
        ">";

    EXPECT_EQ(actual, expected);
}
|
|
|
|
// Reflection check for the CK Tile GroupedConvolutionForwardKernel: instance_string<> must
// flatten the conv traits, the GEMM tile shape, the pipeline and the epilogue into a single
// comma-separated parameter list.
TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
{
    namespace layout  = ck_tile::tensor_layout::convolution;
    using PassThrough = ck_tile::element_wise::PassThrough;
    using BF16        = ck_tile::bf16_t;

    // Tile geometry. TileGemmShape and CShuffleEpilogueProblem each receive these values
    // separately, so they are named once here to keep the two declarations in agreement.
    constexpr int kMPerBlock     = 128;
    constexpr int kNPerBlock     = 128;
    constexpr int kKPerBlock     = 32;
    constexpr int kMWarp         = 4;
    constexpr int kNWarp         = 1;
    constexpr int kKWarp         = 1;
    constexpr int kMWarpTile     = 16;
    constexpr int kNWarpTile     = 16;
    constexpr int kKWarpTile     = 16;
    constexpr int kNumWaveGroups = 1;

    using ConvTraits = ck_tile::GroupedConvTraits<
        /*NDimSpatial=*/2,
        /*ConvSpec=*/ck_tile::ConvolutionSpecialization::Default,
        /*InLayout=*/layout::NHWGC,
        /*WeiLayout=*/layout::GKYXC,
        /*DsLayout=*/ck_tile::tuple<>,
        /*OutLayout=*/layout::NHWGK,
        /*VectorSizeA=*/4,
        /*VectorSizeB=*/4,
        /*VectorSizeC=*/4,
        /*NumGroupsToMerge=*/1,
        /*EnableSplitImage=*/false,
        /*ExplicitGemm=*/false>;
    using Fixed = ConvTraits::FixedGemmParams;

    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<kMPerBlock, kNPerBlock, kKPerBlock>,
                               ck_tile::sequence<kMWarp, kNWarp, kKWarp>,
                               ck_tile::sequence<kMWarpTile, kNWarpTile, kKWarpTile>>;

    using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                                   Fixed::TilePartitionerGroupNum,
                                                                   Fixed::TilePartitionerM01>;

    using UniversalTraits = ck_tile::TileGemmUniversalTraits<Fixed::kPadM,
                                                             Fixed::kPadN,
                                                             Fixed::kPadK,
                                                             /*DoubleSmemBuffer=*/false,
                                                             ConvTraits::AsLayoutFwd,
                                                             ConvTraits::BsLayoutFwd,
                                                             ConvTraits::CLayoutFwd,
                                                             Fixed::TransposeC,
                                                             Fixed::UseStructuredSparsity,
                                                             Fixed::Persistent,
                                                             kNumWaveGroups>;

    using PipelineProblem = ck_tile::UniversalGemmPipelineProblem<
        /*InDataType=*/BF16,
        /*WeiDataType=*/BF16,
        /*AccDataType=*/float,
        GemmShape,
        UniversalTraits,
        /*scheduler=*/ck_tile::GemmPipelineScheduler::Intrawave,
        /*AElementwiseOperation=*/PassThrough,
        /*BElementwiseOperation=*/PassThrough,
        /*OutDataType=*/BF16,
        Fixed::FixedVectorSize,
        ConvTraits::VectorSizeA,
        ConvTraits::VectorSizeB>;

    using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;

    using Epilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
        /*InDataType=*/BF16,
        /*WeiDataType=*/BF16,
        /*DsDataType=*/ck_tile::tuple<>,
        /*AccDataType=*/float,
        /*OutDataType=*/BF16,
        ConvTraits::ImplicitGemmDsLayout,
        Fixed::ELayout,
        /*CDElementWise=*/PassThrough,
        kMPerBlock,
        kNPerBlock,
        kMWarp,
        kNWarp,
        kMWarpTile,
        kNWarpTile,
        kKWarpTile,
        Fixed::TransposeC,
        kNumWaveGroups,
        Fixed::FixedVectorSize,
        ConvTraits::VectorSizeC>>;

    using Kernel =
        ck_tile::GroupedConvolutionForwardKernel<ConvTraits, Partitioner, Pipeline, Epilogue>;

    const std::string actual = ck_tile::reflect::instance_string<Kernel>();

    // One line per logical group; boolean arguments are expected to render as 0/1.
    const std::string expected =
        "GroupedConvolutionForwardKernel<"
        "2,Default,"                    // NDimSpatial, ConvSpecialization
        "NHWGC,GKYXC,EmptyTuple,NHWGK," // In, Wei, Ds, Out layouts
        "4,4,4,"                        // VectorSizeA, VectorSizeB, VectorSizeC
        "1,0,0,"                        // NumGroupsToMerge, EnableSplitImage, ExplicitGemm
        "128,128,32,"                   // MPerBlock, NPerBlock, KPerBlock
        "4,1,1,"                        // MWarp, NWarp, KWarp
        "16,16,16,"                     // MWarpTile, NWarpTile, KWarpTile
        "bf16,bf16,"                    // ADataType, BDataType
        "COMPUTE_V3,Intrawave,"         // BlkGemmPipelineVer, BlkGemmPipeSched
        "0,1,"                          // DoubleSmemBuffer, NumWaveGroups
        "fp32,bf16,EmptyTuple,"         // AccDataType, EDataType, DsDataType
        "PassThrough"                   // CDEElementwiseOperation
        ">";

    EXPECT_EQ(actual, expected);
}
|
|
|
|
} // anonymous namespace
|