Files
composable_kernel/experimental/builder/test/test_fwd_instance_traits.cpp
John Shumway f5b0af2272 Simplify includes for CK builder reflection (#3357)
We only want to import enums and types into the builder reflection code. But, some of the enums are included in much larger files or even big trees of include files. This leads to unintended mixing of code and very confusing interactions and symbol conflicts. We organize the includes and extract two new enum-only headers to help with decoupling in CK. This refactoring is critical if we want to include reflection in a device-operator "describe" method.

* Remove a few unnecessary includes from headers in builder/reflect/.
* Extract enums scheduler and pipeline to their own headers so they can be used without importing other code.
* Order includes alphabetically for better organization.

The immediate goal is to unblock reflection integration, and this type of cleanup helps the flexibility and robustness of the CK header library.
2025-12-05 07:44:10 -08:00

848 lines
53 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <string>
#include <type_traits>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
namespace {
// googlemock matcher used with EXPECT_THAT below to compare array-valued traits
// element by element.
using ::testing::ElementsAre;
// NOTE: The V3ExtractsAllFieldsCorrectly test below performs detailed field extraction testing
// for the V3 variant as a reference implementation. For new InstanceTraits specializations,
// only the instance_string() functionality needs to be tested. Each new specialization should
// have:
//   1. A test using instance_string<T>() directly (in this file).
//   2. A test using GetInstanceString() through a base class pointer (in a separate
//      test_get_instance_string_*.cpp file).
// This prevents test duplication while ensuring both access methods work correctly.
// Reference test: instantiates one concrete V3 XDL forward-convolution device operation and
// checks that InstanceTraits reports each inspected template argument with the value or type
// that was passed in.
TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;

    // Concrete instance whose template arguments are read back through InstanceTraits.
    using DeviceInstance = device_op::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::Default,               // GemmSpec
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
        ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,                      // ABlockTransferSrcVectorDim
        8,                      // ABlockTransferSrcScalarPerVector
        8,                      // ABlockTransferDstScalarPerVector_AK1
        1,                      // ABlockLdsExtraM
        ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
        ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
        2,                      // BBlockTransferSrcVectorDim
        8,                      // BBlockTransferSrcScalarPerVector
        8,                      // BBlockTransferDstScalarPerVector_BK1
        1,                      // BBlockLdsExtraN
        1,                      // CShuffleMXdlPerWavePerShuffle
        1,                      // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,                         // CDEBlockTransferScalarPerVector_NPerBlock
        ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
        ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
        ck::half_t,                                // AComputeDataType
        ck::half_t,                                // BComputeDataType
        false>;                                    // DirectLoad

    // Compile-time view of the instance under test.
    using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;

    // Spatial dimension.
    EXPECT_EQ(Traits::kSpatialDim, 2);

    // Block configuration.
    EXPECT_EQ(Traits::kBlockSize, 256);
    EXPECT_EQ(Traits::kMPerBlock, 128);
    EXPECT_EQ(Traits::kNPerBlock, 128);
    EXPECT_EQ(Traits::kKPerBlock, 16);

    // Tuning parameters.
    EXPECT_EQ(Traits::kAK1, 8);
    EXPECT_EQ(Traits::kBK1, 8);
    EXPECT_EQ(Traits::kMPerXDL, 32);
    EXPECT_EQ(Traits::kNPerXDL, 32);
    EXPECT_EQ(Traits::kMXdlPerWave, 4);
    EXPECT_EQ(Traits::kNXdlPerWave, 4);

    // A block transfer scalars.
    EXPECT_EQ(Traits::kABlockTransferSrcVectorDim, 2);
    EXPECT_EQ(Traits::kABlockTransferSrcScalarPerVector, 8);
    EXPECT_EQ(Traits::kABlockTransferDstScalarPerVectorK1, 8);
    EXPECT_EQ(Traits::kABlockLdsExtraM, 1);

    // B block transfer scalars.
    EXPECT_EQ(Traits::kBBlockTransferSrcVectorDim, 2);
    EXPECT_EQ(Traits::kBBlockTransferSrcScalarPerVector, 8);
    EXPECT_EQ(Traits::kBBlockTransferDstScalarPerVectorK1, 8);
    EXPECT_EQ(Traits::kBBlockLdsExtraN, 1);

    // C shuffle parameters.
    EXPECT_EQ(Traits::kCShuffleMXdlPerWavePerShuffle, 1);
    EXPECT_EQ(Traits::kCShuffleNXdlPerWavePerShuffle, 1);
    EXPECT_EQ(Traits::kCBlockTransferScalarPerVector, 8);

    // Pipeline configuration.
    EXPECT_EQ(Traits::kPipelineScheduler, ck::BlockGemmPipelineScheduler::Intrawave);
    EXPECT_EQ(Traits::kPipelineVersion, ck::BlockGemmPipelineVersion::v1);

    // Primary data types. The extra parentheses keep the comma inside the template argument
    // list from being parsed as a macro argument separator.
    EXPECT_TRUE((std::is_same_v<Traits::ADataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::BDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::AccDataType, float>));
    EXPECT_TRUE((std::is_same_v<Traits::EDataType, ck::half_t>));

    // Primary layout types.
    EXPECT_TRUE((std::is_same_v<Traits::ALayout, conv_layout::GNHWC>));
    EXPECT_TRUE((std::is_same_v<Traits::BLayout, conv_layout::GKYXC>));
    EXPECT_TRUE((std::is_same_v<Traits::ELayout, conv_layout::GNHWK>));

    // Thread cluster lengths, compared element by element.
    EXPECT_THAT(Traits::kAThreadClusterLengths, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::kBThreadClusterLengths, ElementsAre(4, 64, 1));
    EXPECT_THAT(Traits::kCThreadClusterLengths, ElementsAre(1, 32, 1, 8));

    // A block transfer arrange order and source access order.
    EXPECT_THAT(Traits::kAThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::kABlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));

    // B block transfer arrange order and source access order.
    EXPECT_THAT(Traits::kBThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
    EXPECT_THAT(Traits::kBBlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));

    // Remaining data types.
    EXPECT_TRUE((std::is_same_v<Traits::CShuffleDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::DsDataType, EmptyTuple>));
    EXPECT_TRUE((std::is_same_v<Traits::AComputeDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::BComputeDataType, ck::half_t>));

    // Remaining layout types.
    EXPECT_TRUE((std::is_same_v<Traits::DsLayout, EmptyTuple>));

    // Element-wise operations.
    EXPECT_TRUE((std::is_same_v<Traits::AElementwiseOperation, PassThrough>));
    EXPECT_TRUE((std::is_same_v<Traits::BElementwiseOperation, PassThrough>));
    EXPECT_TRUE((std::is_same_v<Traits::CDEElementwiseOperation, PassThrough>));
}
// Checks that instance_string<T>() renders a V3 XDL forward-convolution instance as the
// operation name followed by its comma-separated template arguments.
TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;

    using DeviceInstance = device_op::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::Default,               // GemmSpec
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
        ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,                      // ABlockTransferSrcVectorDim
        8,                      // ABlockTransferSrcScalarPerVector
        8,                      // ABlockTransferDstScalarPerVector_AK1
        1,                      // ABlockLdsExtraM
        ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
        ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
        2,                      // BBlockTransferSrcVectorDim
        8,                      // BBlockTransferSrcScalarPerVector
        8,                      // BBlockTransferDstScalarPerVector_BK1
        1,                      // BBlockLdsExtraN
        1,                      // CShuffleMXdlPerWavePerShuffle
        1,                      // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,                         // CDEBlockTransferScalarPerVector_NPerBlock
        ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
        ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
        ck::half_t,                                // AComputeDataType
        ck::half_t,                                // BComputeDataType
        false>;                                    // DirectLoad

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // Adjacent literals concatenate into one string; each line covers one group of
    // template arguments, listed in declaration order.
    const std::string expected =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
        "<2"                                   // NDimSpatial
        ",GNHWC,GKYXC,EmptyTuple,GNHWK"        // ALayout, BLayout, DsLayout, ELayout
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16" // A, B, Acc, CShuffle, Ds, E data types
        ",PassThrough,PassThrough,PassThrough" // A, B, CDE elementwise operations
        ",Default,Default"                     // ConvForwardSpecialization, GemmSpec
        ",256,128,128,16"                      // BlockSize, MPerBlock, NPerBlock, KPerBlock
        ",8,8,32,32,4,4" // AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // A: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // A: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_AK1, LdsExtraM
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // B: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // B: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_BK1, LdsExtraN
        ",1,1"     // CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle
        ",Seq(1,32,1,8),8" // CDEBlockTransferClusterLengths, CDEBlockTransferScalarPerVector_NPerBlock
        ",Intrawave,v1"    // BlkGemmPipeSched, BlkGemmPipelineVer
        ",fp16,fp16"       // AComputeDataType, BComputeDataType
        ",false>";         // DirectLoad

    EXPECT_EQ(actual, expected);
}
// Checks the instance_string<T>() rendering of the non-V3 XDL CShuffle forward-convolution
// device operation.
TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;

    using DeviceInstance = device_op::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::Default,               // GemmSpec
        1,                                                    // NumGemmKPrefetchStage
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
        ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,                      // ABlockTransferSrcVectorDim
        8,                      // ABlockTransferSrcScalarPerVector
        8,                      // ABlockTransferDstScalarPerVector_AK1
        1,                      // ABlockLdsExtraM
        ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
        ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder
        2,                      // BBlockTransferSrcVectorDim
        8,                      // BBlockTransferSrcScalarPerVector
        8,                      // BBlockTransferDstScalarPerVector_BK1
        1,                      // BBlockLdsExtraN
        1,                      // CShuffleMXdlPerWavePerShuffle
        1,                      // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,                         // CDEBlockTransferScalarPerVector_NPerBlock
        ck::half_t,                // AComputeDataType
        ck::half_t,                // BComputeDataType
        ck::LoopScheduler::Default, // LoopSched
        1>;                         // NumGroupsToMerge

    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // Adjacent literals concatenate into one string; each line covers one group of
    // template arguments, listed in declaration order.
    const std::string expected =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
        "<2"                                   // NDimSpatial
        ",GNHWC,GKYXC,EmptyTuple,GNHWK"        // ALayout, BLayout, DsLayout, ELayout
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16" // A, B, Acc, CShuffle, Ds, E data types
        ",PassThrough,PassThrough,PassThrough" // A, B, CDE elementwise operations
        ",Default,Default"                     // ConvForwardSpecialization, GemmSpec
        ",1"                                   // NumGemmKPrefetchStage
        ",256,128,128,16"                      // BlockSize, MPerBlock, NPerBlock, KPerBlock
        ",8,8,32,32,4,4" // AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // A: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // A: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_AK1, LdsExtraM
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // B: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // B: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_BK1, LdsExtraN
        ",1,1"     // CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle
        ",Seq(1,32,1,8),8" // CDEBlockTransferClusterLengths, CDEBlockTransferScalarPerVector_NPerBlock
        ",fp16,fp16"       // AComputeDataType, BComputeDataType
        ",Default"         // LoopSched
        ",1>";             // NumGroupsToMerge

    EXPECT_EQ(actual, expected);
}
// Checks the instance_string<T>() rendering of the large-tensor XDL CShuffle
// forward-convolution device operation.
TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;

    using DeviceInstance = device_op::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::Default,               // GemmSpec
        1,                                                    // NumGemmKPrefetchStage
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ck::Sequence<4, 64, 1>,    // ABlockTransferThreadClusterLengths_AK0_M_AK1
        ck::Sequence<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,     // ABlockTransferSrcAccessOrder
        2,                         // ABlockTransferSrcVectorDim
        8,                         // ABlockTransferSrcScalarPerVector
        8,                         // ABlockTransferDstScalarPerVector_AK1
        1,                         // ABlockLdsExtraM
        ck::Sequence<4, 64, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
        ck::Sequence<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,     // BBlockTransferSrcAccessOrder
        2,                         // BBlockTransferSrcVectorDim
        8,                         // BBlockTransferSrcScalarPerVector
        8,                         // BBlockTransferDstScalarPerVector_BK1
        1,                         // BBlockLdsExtraN
        1,                         // CShuffleMXdlPerWavePerShuffle
        1,                         // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
        8,                         // CDEBlockTransferScalarPerVector_NPerBlock
        ck::half_t,                // AComputeDataType
        ck::half_t,                // BComputeDataType
        ck::LoopScheduler::Default>; // LoopSched

    // String produced by the reflection helper for this instance.
    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // Expected rendering of all 48 template arguments. Adjacent literals concatenate into
    // one string; each line covers one group of arguments, listed in declaration order.
    const std::string expected =
        "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
        "<2"                                   // NDimSpatial
        ",GNHWC,GKYXC,EmptyTuple,GNHWK"        // ALayout, BLayout, DsLayout, ELayout
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16" // A, B, Acc, CShuffle, Ds, E data types
        ",PassThrough,PassThrough,PassThrough" // A, B, CDE elementwise operations
        ",Default,Default"                     // ConvForwardSpecialization, GemmSpec
        ",1"                                   // NumGemmKPrefetchStage
        ",256,128,128,16"                      // BlockSize, MPerBlock, NPerBlock, KPerBlock
        ",8,8,32,32,4,4" // AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // A: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // A: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_AK1, LdsExtraM
        ",Seq(4,64,1),Seq(1,0,2),Seq(1,0,2)" // B: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,8,8,1" // B: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_BK1, LdsExtraN
        ",1,1"     // CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle
        ",Seq(1,32,1,8),8" // CDEBlockTransferClusterLengths, CDEBlockTransferScalarPerVector_NPerBlock
        ",fp16,fp16"       // AComputeDataType, BComputeDataType
        ",Default>";       // LoopSched

    // The generated string must match exactly.
    EXPECT_EQ(actual, expected);
}
// Checks the instance_string<T>() rendering of the WMMA CShuffle forward-convolution device
// operation.
TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;

    using DeviceInstance = device_op::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::MNKPadding,            // GemmSpec
        1,                                                    // NumGemmKPrefetchStage
        128,                                                  // BlockSize
        64,                                                   // MPerBlock
        64,                                                   // NPerBlock
        32,                                                   // KPerBlock
        8,                                                    // K1
        16,                                                   // MPerWmma
        16,                                                   // NPerWmma
        2,                                                    // MRepeat
        2,                                                    // NRepeat
        ck::Sequence<4, 32, 1>,    // ABlockTransferThreadClusterLengths_AK0_M_AK1
        ck::Sequence<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,     // ABlockTransferSrcAccessOrder
        2,                         // ABlockTransferSrcVectorDim
        1,                         // ABlockTransferSrcScalarPerVector
        8,                         // ABlockTransferDstScalarPerVector_AK1
        1,                         // ABlockLdsExtraM
        ck::Sequence<4, 32, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
        ck::Sequence<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
        ck::Sequence<1, 0, 2>,     // BBlockTransferSrcAccessOrder
        2,                         // BBlockTransferSrcVectorDim
        1,                         // BBlockTransferSrcScalarPerVector
        8,                         // BBlockTransferDstScalarPerVector_BK1
        1,                         // BBlockLdsExtraN
        1,                         // CShuffleMRepeatPerShuffle
        1,                         // CShuffleNRepeatPerShuffle
        ck::Sequence<1, 32, 1, 4>, // CDEShuffleBlockTransferClusterLengths
        1,                         // CDEShuffleBlockTransferScalarPerVector_NPerBlock
        ck::LoopScheduler::Default, // LoopSched
        ck::PipelineVersion::v1>;   // PipelineVer

    // String produced by the reflection helper for this instance.
    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // Expected rendering of all 46 template arguments. Adjacent literals concatenate into
    // one string; each line covers one group of arguments, listed in declaration order.
    // NOTE(review): the LDS-extra arguments are passed as 1 above but expected as "true"
    // here, presumably because this operation declares them as bool - confirm against the
    // device operation header.
    const std::string expected =
        "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
        "<2"                                   // NDimSpatial
        ",GNHWC,GKYXC,EmptyTuple,GNHWK"        // ALayout, BLayout, DsLayout, ELayout
        ",fp16,fp16,fp32,fp16,EmptyTuple,fp16" // A, B, Acc, CShuffle, Ds, E data types
        ",PassThrough,PassThrough,PassThrough" // A, B, CDE elementwise operations
        ",Default,MNKPadding"                  // ConvForwardSpecialization, GemmSpec
        ",1"                                   // NumGemmKPrefetchStage
        ",128,64,64,32"                        // BlockSize, MPerBlock, NPerBlock, KPerBlock
        ",8,16,16,2,2"                         // K1, MPerWmma, NPerWmma, MRepeat, NRepeat
        ",Seq(4,32,1),Seq(1,0,2),Seq(1,0,2)" // A: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,1,8,true" // A: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_AK1, LdsExtraM
        ",Seq(4,32,1),Seq(1,0,2),Seq(1,0,2)" // B: ThreadClusterLengths, ArrangeOrder, SrcAccessOrder
        ",2,1,8,true" // B: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_BK1, LdsExtraN
        ",1,1"        // CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle
        ",Seq(1,32,1,4),1" // CDEShuffleBlockTransferClusterLengths, ...ScalarPerVector_NPerBlock
        ",Default,v1>";    // LoopSched, PipelineVer

    // The generated string must match exactly.
    EXPECT_EQ(actual, expected);
}
// Checks the instance_string<T>() rendering of the DL NHWC/KYXC/NHWK forward-convolution
// device operation.
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated in the instance definition.
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace device_op   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple      = ck::Tuple<>;
    using Order1203       = ck::Sequence<1, 2, 0, 3>; // shared by every 4-element order argument
    using Ones4           = ck::Sequence<1, 1, 1, 1>; // shared by the all-ones length arguments

    using DeviceInstance = device_op::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
        2,                                                    // NDimSpatial
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        EmptyTuple,                                           // DsDataType
        ck::half_t,                                           // EDataType
        float,                                                // AccDataType
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        EmptyTuple,                                           // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        device_op::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        device_op::GemmSpecialization::MNKPadding,            // GemmSpec
        8,                                                    // BlockSize
        16,                                                   // MPerBlock
        4,                                                    // NPerBlock
        2,                                                    // K0PerBlock
        1,                                                    // K1
        1,                                                    // M1PerThread
        2,                                                    // N1PerThread
        1,                                                    // KPerThread
        ck::Sequence<4, 2>,             // M1N1ThreadClusterM1Xs
        ck::Sequence<1, 1>,             // M1N1ThreadClusterN1Xs
        ck::Sequence<2, 1, 2, 1>,       // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
        ck::Sequence<1, 1, 8, 1>,       // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
        Order1203,                      // ABlockTransferThreadClusterArrangeOrder
        Order1203,                      // ABlockTransferSrcAccessOrder
        Ones4,                          // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
        Order1203,                      // ABlockTransferSrcVectorTensorContiguousDimOrder
        Ones4,                          // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
        Ones4,                          // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
        ck::Sequence<2, 1, 4, 1>,       // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
        Order1203,                      // BBlockTransferThreadClusterArrangeOrder
        Order1203,                      // BBlockTransferSrcAccessOrder
        Ones4,                          // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
        Order1203,                      // BBlockTransferSrcVectorTensorContiguousDimOrder
        Ones4,                          // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
        ck::Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
        5,                              // CThreadTransferSrcDstVectorDim
        1>;                             // CThreadTransferDstScalarPerVector

    // String produced by the reflection helper for this instance.
    const std::string actual = ck_tile::reflect::instance_string<DeviceInstance>();

    // Expected rendering of all 42 template arguments. Adjacent literals concatenate into
    // one string; each line covers one group of arguments, listed in declaration order.
    const std::string expected =
        "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
        "<2"                                   // NDimSpatial
        ",fp16,fp16,EmptyTuple,fp16,fp32"      // A, B, Ds, E, Acc data types
        ",GNHWC,GKYXC,EmptyTuple,GNHWK"        // ALayout, BLayout, DsLayout, ELayout
        ",PassThrough,PassThrough,PassThrough" // A, B, CDE elementwise operations
        ",Default,MNKPadding"                  // ConvForwardSpecialization, GemmSpec
        ",8,16,4,2"                            // BlockSize, MPerBlock, NPerBlock, K0PerBlock
        ",1,1,2,1"                             // K1, M1PerThread, N1PerThread, KPerThread
        ",Seq(4,2),Seq(1,1)"         // M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
        ",Seq(2,1,2,1),Seq(1,1,8,1)" // A: ThreadSliceLengths, ThreadClusterLengths
        ",Seq(1,2,0,3),Seq(1,2,0,3)" // A: ThreadClusterArrangeOrder, SrcAccessOrder
        ",Seq(1,1,1,1),Seq(1,2,0,3),Seq(1,1,1,1)" // A: SrcVectorTensorLengths,
                                                  // SrcVectorTensorContiguousDimOrder,
                                                  // DstVectorTensorLengths
        ",Seq(1,1,1,1),Seq(2,1,4,1)" // B: ThreadSliceLengths, ThreadClusterLengths
        ",Seq(1,2,0,3),Seq(1,2,0,3)" // B: ThreadClusterArrangeOrder, SrcAccessOrder
        ",Seq(1,1,1,1),Seq(1,2,0,3),Seq(1,1,1,1)" // B: SrcVectorTensorLengths,
                                                  // SrcVectorTensorContiguousDimOrder,
                                                  // DstVectorTensorLengths
        ",Seq(0,1,2,3,4,5)"          // CThreadTransferSrcDstAccessOrder
        ",5"                         // CThreadTransferSrcDstVectorDim
        ",1>";                       // CThreadTransferDstScalarPerVector

    // The generated string must match exactly.
    EXPECT_EQ(actual, expected);
}
// Assembles a ck_tile GroupedConvolutionForwardKernel from its traits, tile partitioner,
// GEMM pipeline and epilogue, then checks its instance_string<T>() rendering.
TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
{
    // Local shorthands for the namespaces and types repeated below.
    namespace tile_layout = ck_tile::tensor_layout::convolution;
    using PassThrough     = ck_tile::element_wise::PassThrough;
    using EmptyTuple      = ck_tile::tuple<>;

    // Convolution-level traits: layouts, vector sizes and feature switches.
    using ConvTraits = ck_tile::GroupedConvTraits<
        2,                                           // NDimSpatial
        ck_tile::ConvolutionSpecialization::Default, // ConvSpec
        tile_layout::NHWGC,                          // InLayout
        tile_layout::GKYXC,                          // WeiLayout
        EmptyTuple,                                  // DsLayout
        tile_layout::NHWGK,                          // OutLayout
        4,                                           // VectorSizeA
        4,                                           // VectorSizeB
        4,                                           // VectorSizeC
        1,                                           // NumGroupsToMerge
        false,                                       // EnableSplitImage
        false>;                                      // ExplicitGemm

    // GEMM parameters exposed by the traits type; reused by every stage below.
    using FixedParams = ConvTraits::FixedGemmParams;

    // Block tile, warp layout and warp tile of the implicit GEMM.
    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<128, 128, 32>, // M_Tile, N_Tile, K_Tile
                               ck_tile::sequence<4, 1, 1>,      // M_Warp, N_Warp, K_Warp
                               ck_tile::sequence<16, 16, 16>>;  // M/N/K_Warp_Tile

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                   FixedParams::TilePartitionerGroupNum,
                                                   FixedParams::TilePartitionerM01>;

    using GemmUniversalTraits =
        ck_tile::TileGemmUniversalTraits<FixedParams::kPadM,
                                         FixedParams::kPadN,
                                         FixedParams::kPadK,
                                         false, // DoubleSmemBuffer
                                         ConvTraits::AsLayoutFwd,
                                         ConvTraits::BsLayoutFwd,
                                         ConvTraits::CLayoutFwd,
                                         FixedParams::TransposeC,
                                         FixedParams::UseStructuredSparsity,
                                         FixedParams::Persistent,
                                         1>; // NumWaveGroups

    using UniversalGemmProblem =
        ck_tile::UniversalGemmPipelineProblem<ck_tile::bf16_t, // InDataType
                                              ck_tile::bf16_t, // WeiDataType
                                              float,           // AccDataType
                                              GemmShape,
                                              GemmUniversalTraits,
                                              ck_tile::GemmPipelineScheduler::Intrawave, // scheduler
                                              PassThrough,     // AElementwiseOperation
                                              PassThrough,     // BElementwiseOperation
                                              ck_tile::bf16_t, // OutDataType
                                              FixedParams::FixedVectorSize,
                                              ConvTraits::VectorSizeA,
                                              ConvTraits::VectorSizeB>;

    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    using EpilogueProblem =
        ck_tile::CShuffleEpilogueProblem<ck_tile::bf16_t, // InDataType
                                         ck_tile::bf16_t, // WeiDataType
                                         EmptyTuple,      // DsDataType
                                         float,           // AccDataType
                                         ck_tile::bf16_t, // OutDataType
                                         ConvTraits::ImplicitGemmDsLayout,
                                         FixedParams::ELayout,
                                         PassThrough, // CDElementWise
                                         128,         // MPerBlock
                                         128,         // NPerBlock
                                         4,           // M_Warp
                                         1,           // N_Warp
                                         16,          // M_Warp_Tile
                                         16,          // N_Warp_Tile
                                         16,          // K_Warp_Tile
                                         FixedParams::TransposeC,
                                         ck_tile::memory_operation_enum::set, // memory_operation
                                         1,                                   // kNumWaveGroups
                                         FixedParams::FixedVectorSize,
                                         ConvTraits::VectorSizeC>;
    using ConvEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;

    // Kernel type whose description string is under test.
    using GroupedConvFwdKernel = ck_tile::
        GroupedConvolutionForwardKernel<ConvTraits, TilePartitioner, GemmPipeline, ConvEpilogue>;

    const std::string actual = ck_tile::reflect::instance_string<GroupedConvFwdKernel>();

    // Adjacent literals concatenate into one string; each line covers one group of fields
    // in the order they appear in the rendered string.
    const std::string expected =
        "GroupedConvolutionForwardKernel"
        "<2,Default"                    // NDimSpatial, ConvSpecialization
        ",NHWGC,GKYXC,EmptyTuple,NHWGK" // InLayout, WeiLayout, DsLayout, OutLayout
        ",4,4,4"                        // VectorSizeA, VectorSizeB, VectorSizeC
        ",1,0,0"                        // NumGroupsToMerge, EnableSplitImage, ExplicitGemm
        ",128,128,32"                   // MPerBlock, NPerBlock, KPerBlock
        ",4,1,1"                        // MWarp, NWarp, KWarp
        ",16,16,16"                     // MWarpTile, NWarpTile, KWarpTile
        ",bf16,bf16"                    // ADataType, BDataType
        ",COMPUTE_V3,Intrawave"         // BlkGemmPipelineVer, BlkGemmPipeSched
        ",0,1"                          // DoubleSmemBuffer, NumWaveGroups
        ",fp32,bf16,EmptyTuple"         // AccDataType, EDataType, DsDataType
        ",PassThrough"                  // CDEElementwiseOperation
        ">";

    EXPECT_EQ(actual, expected);
}
} // anonymous namespace