[CK_BUILDER] Convolution traits. (#3152)

Added:

1. Convolution traits & unit tests
2. Update builder enumerators to have representation of Convolution Kernels properties.
3. Unified builder pipeline version & scheduler enumerators
This commit is contained in:
Adam Osewski
2025-11-05 17:53:06 +01:00
committed by GitHub
parent 3b076b0b74
commit b8527a9236
20 changed files with 1165 additions and 81 deletions

View File

@@ -16,7 +16,7 @@ using namespace test;
// Common test implementation
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
PipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
{
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
.src_access_order_b = {1, 0, 2}};
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
if(FwdPipelineVersion == PipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
else if(FwdPipelineVersion == PipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
else if(FwdPipelineVersion == PipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
else if(FwdPipelineVersion == PipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 2,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = GridwiseGemmPipelineVersion::V1};
.pipeline_version = PipelineVersion::V1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;