[CK_BUILDER] Refactor builder factory code. (#3276)

Refactor the builder factory code into multiple files and subdirectories and a ck_tile::builder::factory namespace.

The factory implements compile-time dispatch from high-level signature and algorithm descriptors to our existing specialized convolution kernel implementations.

Major changes in this PR:

Dispatch logic is explicit in the function make_conv_instance instead of implicit in template specialization selection.
Helper code is moved to a subdirectory builder/factory/helpers.
Helpers now have unit tests.
Factories are moved to their own files.
Code moved to namespaces ck_tile::builder::factory and ck_tile::builder::factory::internal.
This does not yet fix the problem of bad error messages, but the make_conv_instance function makes the poor error messages clear. The choice of algorithm must be much more robust (perhaps with explicit enumeration in the algorithm descriptor), so that the dispatch doesn't fail.

Quality changes:

Making dispatch explicit rather than implicit will improve robustness, readability, maintainability, testability, and extensibility.
Separating code into separate files and subdirectories helps readability and extensibility.
Adding unit tests for helpers documents behavior and will enable more complex logic and functionality.
Separating files (especially unit tests) helps clarify includes and dependencies and makes code easier to refactor.
This commit is contained in:
John Shumway
2025-12-02 07:40:14 -08:00
committed by GitHub
parent 8459d389ad
commit 280bc42191
29 changed files with 1782 additions and 1134 deletions

View File

@@ -0,0 +1,90 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
namespace {
namespace ckb = ::ck_tile::builder;
using namespace ck_tile::builder;
using namespace ck_tile::builder::factory::internal;
TEST(ConvTuningParams, AssignsBlockGemmParams)
{
constexpr struct Algorithm
{
struct BlockGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
} block_gemm;
} kAlgorithm;
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
EXPECT_EQ(block_gemm.pipeline_version, ck::BlockGemmPipelineVersion::v3);
EXPECT_EQ(block_gemm.scheduler, ck::BlockGemmPipelineScheduler::Intrawave);
}
TEST(ConvTuningParams, AssignsLoopSchedulerParam)
{
constexpr struct Algorithm
{
ckb::PipelineScheduler loop_scheduler = ckb::PipelineScheduler::INTERWAVE;
} kAlgorithm;
constexpr auto loop_scheduler = SetLoopScheduler<kAlgorithm>();
EXPECT_EQ(loop_scheduler, ck::LoopScheduler::Interwave);
}
TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
constexpr struct Algorithm
{
struct GridwiseGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} gridwise_gemm;
} kAlgorithm;
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
EXPECT_EQ(pipeline_version, ck::PipelineVersion::v4);
}
TEST(ConvTuningParams, AssignsGemmSpecialization)
{
constexpr struct Algorithm
{
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::MNKPadding;
} kAlgorithm;
constexpr auto gemm_spec = SetGemmSpecialization<kAlgorithm>();
EXPECT_EQ(gemm_spec, ck::tensor_operation::device::GemmSpecialization::MNKPadding);
}
TEST(ConvTuningParams, AssignsBlockGemmPipelineVersion)
{
constexpr struct Algorithm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V2;
} kAlgorithm;
constexpr auto pipeline_version = SetBlockGemmPipelineVersion<kAlgorithm>();
EXPECT_EQ(pipeline_version, ck::BlockGemmPipelineVersion::v2);
}
TEST(ConvTuningParams, AssignsFwdConvSpecialization)
{
constexpr struct Algorithm
{
ckb::ConvFwdSpecialization fwd_specialization =
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
} kAlgorithm;
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();
EXPECT_EQ(conv_spec,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
}
} // namespace