[CK_BUILDER] Convolution traits. (#3152)

Added:

1. Convolution traits & unit tests
2. Update builder enumerators to have representation of Convolution Kernels properties.
3. Unified builder pipeline version & scheduler enumerators
This commit is contained in:
Adam Osewski
2025-11-05 17:53:06 +01:00
committed by GitHub
parent 3b076b0b74
commit b8527a9236
20 changed files with 1165 additions and 81 deletions

View File

@@ -27,7 +27,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
PipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
@@ -47,7 +47,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
PipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <concepts>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for ConvTraits tests
class ConvTraitsTest : public ::testing::Test
{
};
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::PipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
// Verify A tile transfer info
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default, // LoopSched
1>; // NumGroupsToMerge
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default>; // LoopSched
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
}
} // anonymous namespace