mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add the last two forward instance traits. (#3134)
* Add InstanceTraits for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle * Add InstanceTraits for kernel_grouped_conv_fwd_dl_multiple_d * A few small changes to fix broken instance traits.
This commit is contained in:
@@ -28,7 +28,9 @@ add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
|
||||
add_ck_builder_test(test_ckb_get_instance_string
|
||||
test_get_instance_string_fwd_grp_conv_v3.cpp
|
||||
test_get_instance_string_fwd_grp_conv.cpp
|
||||
test_get_instance_string_fwd_grp_conv_large_tensor.cpp)
|
||||
test_get_instance_string_fwd_grp_conv_large_tensor.cpp
|
||||
test_get_instance_string_fwd_grp_conv_wmma.cpp
|
||||
test_get_instance_string_fwd_grp_conv_dl.cpp)
|
||||
|
||||
# Testing the fwd convolution builder requires kernel compilation.
|
||||
# To enable parallel compilation, the individual tests are split into separate files.
|
||||
|
||||
@@ -9,12 +9,21 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
// NOTE: The V3ExtractsAllFieldsCorrectly test below performs detailed field extraction testing
|
||||
// for the V3 variant as a reference implementation. For new InstanceTraits specializations,
|
||||
// only the instance_string() functionality needs to be tested. Each new specialization should have:
|
||||
// 1. A test using instance_string<T>() directly (in this file)
|
||||
// 2. A test using GetInstanceString() through base class pointer (in separate
|
||||
// test_get_instance_string_*.cpp file) This prevents test duplication while ensuring both access
|
||||
// methods work correctly.
|
||||
TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -70,7 +79,8 @@ TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t>; // BComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>;
|
||||
|
||||
// Use InstanceTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
|
||||
@@ -155,7 +165,7 @@ TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
ck::tensor_operation::element_wise::PassThrough>::value));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, V3InstanceStringGeneration)
|
||||
TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -211,7 +221,8 @@ TEST(InstanceTraitsTest, V3InstanceStringGeneration)
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t>; // BComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
@@ -263,12 +274,13 @@ TEST(InstanceTraitsTest, V3InstanceStringGeneration)
|
||||
",Intrawave" // BlkGemmPipeSched
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16>"; // BComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",false>"; // DirectLoad
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
|
||||
TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
@@ -382,7 +394,7 @@ TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
|
||||
TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
@@ -497,4 +509,215 @@ TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
128, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
2, // MRepeat
|
||||
2, // NRepeat
|
||||
ck::Sequence<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
4>, // CDEShuffleBlockTransferClusterLengths
|
||||
1, // CDEShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
ck::LoopScheduler::Default, // LoopSched
|
||||
ck::PipelineVersion::v1>; // PipelineVer
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all 46 template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",128" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // K1
|
||||
",16" // MPerWmma
|
||||
",16" // NPerWmma
|
||||
",2" // MRepeat
|
||||
",2" // NRepeat
|
||||
",Seq(4,32,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,32,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMRepeatPerShuffle
|
||||
",1" // CShuffleNRepeatPerShuffle
|
||||
",Seq(1,32,1,4)" // CDEShuffleBlockTransferClusterLengths
|
||||
",1" // CDEShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
",Default" // LoopSched
|
||||
",v1>"; // PipelineVer
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
2, // NDimSpatial
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
|
||||
8, // BlockSize
|
||||
16, // MPerBlock
|
||||
4, // NPerBlock
|
||||
2, // K0PerBlock
|
||||
1, // K1
|
||||
1, // M1PerThread
|
||||
2, // N1PerThread
|
||||
1, // KPerThread
|
||||
ck::Sequence<4, 2>, // M1N1ThreadClusterM1Xs
|
||||
ck::Sequence<1, 1>, // M1N1ThreadClusterN1Xs
|
||||
ck::Sequence<2, 1, 2, 1>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
|
||||
ck::Sequence<1, 1, 8, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
|
||||
ck::Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
|
||||
ck::Sequence<1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
|
||||
ck::Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
ck::Sequence<1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
|
||||
ck::Sequence<1, 1, 1, 1>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
|
||||
ck::Sequence<2, 1, 4, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
|
||||
ck::Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
|
||||
ck::Sequence<1, 1, 1, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
|
||||
ck::Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
ck::Sequence<1, 1, 1, 1>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
|
||||
ck::Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all 42 template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
|
||||
"<2" // NDimSpatial
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",fp32" // AccDataType
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",8" // BlockSize
|
||||
",16" // MPerBlock
|
||||
",4" // NPerBlock
|
||||
",2" // K0PerBlock
|
||||
",1" // K1
|
||||
",1" // M1PerThread
|
||||
",2" // N1PerThread
|
||||
",1" // KPerThread
|
||||
",Seq(4,2)" // M1N1ThreadClusterM1Xs
|
||||
",Seq(1,1)" // M1N1ThreadClusterN1Xs
|
||||
",Seq(2,1,2,1)" // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
|
||||
",Seq(1,1,8,1)" // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
|
||||
",Seq(1,2,0,3)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,2,0,3)" // ABlockTransferSrcAccessOrder
|
||||
",Seq(1,1,1,1)" // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
|
||||
",Seq(1,2,0,3)" // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
",Seq(1,1,1,1)" // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
|
||||
",Seq(1,1,1,1)" // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
|
||||
",Seq(2,1,4,1)" // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
|
||||
",Seq(1,2,0,3)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,2,0,3)" // BBlockTransferSrcAccessOrder
|
||||
",Seq(1,1,1,1)" // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
|
||||
",Seq(1,2,0,3)" // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
",Seq(1,1,1,1)" // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
|
||||
",Seq(0,1,2,3,4,5)" // CThreadTransferSrcDstAccessOrder
|
||||
",5" // CThreadTransferSrcDstVectorDim
|
||||
",1>"; // CThreadTransferDstScalarPerVector
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for non-V3 variant
|
||||
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for DL variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvDlInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
ck::tensor_operation::device::instance::device_grouped_conv2d_fwd_dl_f16_instances<
|
||||
ck::tensor_operation::device::instance::GNHWC, // InLayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // WeiLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // OutLayout
|
||||
ck::Tuple<>, // DsDatatype
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementOp
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvSpec
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv2d_fwd_dl_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
|
||||
"<2" // NDimSpatial
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",fp32" // AccDataType
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",8" // BlockSize
|
||||
",16" // MPerBlock
|
||||
",4" // NPerBlock
|
||||
",2" // K0PerBlock
|
||||
",1" // K1
|
||||
",1" // M1PerThread
|
||||
",2" // N1PerThread
|
||||
",1" // KPerThread
|
||||
",Seq(4,2)" // M1N1ThreadClusterM1Xs
|
||||
",Seq(1,1)" // M1N1ThreadClusterN1Xs
|
||||
",Seq(2,1,2,1)" // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
|
||||
",Seq(1,1,8,1)" // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
|
||||
",Seq(1,2,0,3)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,2,0,3)" // ABlockTransferSrcAccessOrder
|
||||
",Seq(1,1,1,1)" // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
|
||||
",Seq(1,2,0,3)" // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
",Seq(1,1,1,1)" // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
|
||||
",Seq(1,1,1,1)" // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
|
||||
",Seq(2,1,4,1)" // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
|
||||
",Seq(1,2,0,3)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,2,0,3)" // BBlockTransferSrcAccessOrder
|
||||
",Seq(1,1,1,1)" // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
|
||||
",Seq(1,2,0,3)" // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
",Seq(1,1,1,1)" // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
|
||||
",Seq(0,1,2,3,4,5)" // CThreadTransferSrcDstAccessOrder
|
||||
",5" // CThreadTransferSrcDstVectorDim
|
||||
",1>"; // CThreadTransferDstScalarPerVector
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for large tensor variant
|
||||
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for V3 variant
|
||||
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
@@ -99,6 +85,7 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
|
||||
",Intrawave" // BlkGemmPipeSched
|
||||
",v4" // BlkGemmPipelineVer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16>"; // BComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",false>"; // DirectLoad
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for Wmma variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvWmmaInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
ck::tensor_operation::device::instance::device_grouped_conv_fwd_wmma_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::Tuple<>, // DsDatatype
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementOp
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvSpec
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using the most general operator base
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_wmma_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",128" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // K1
|
||||
",16" // MPerWmma
|
||||
",16" // NPerWmma
|
||||
",2" // MRepeat
|
||||
",2" // NRepeat
|
||||
",Seq(4,32,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,32,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMRepeatPerShuffle
|
||||
",1" // CShuffleNRepeatPerShuffle
|
||||
",Seq(1,32,1,4)" // CDEShuffleBlockTransferClusterLengths
|
||||
",1" // CDEShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
",Default" // LoopSched
|
||||
",v1>"; // PipelineVer
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -202,8 +202,8 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
EXPECT_THAT(std::vector<std::string_view> names = {loop_scheduler_name(Default),
|
||||
loop_scheduler_name(Interwave)},
|
||||
EXPECT_THAT((std::vector<std::string_view>{loop_scheduler_name(Default),
|
||||
loop_scheduler_name(Interwave)}),
|
||||
ElementsAre("Default", "Interwave"));
|
||||
}
|
||||
|
||||
@@ -267,5 +267,15 @@ TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TypeOrTypeTupleNameReturnsCorrectStringForScalarDataType)
|
||||
{
|
||||
EXPECT_EQ(type_or_type_tuple_name<float>(), "fp32");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TypeOrTypeTupleNameReturnsCorrectStringForTupleOfDataTypes)
|
||||
{
|
||||
EXPECT_EQ((type_or_type_tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
Reference in New Issue
Block a user