[CK_BUILDER] Add bwd weight factories (#3509)

* Add placeholder test.

* Initial conv bwd weight factory.

* Conv builder test refactoring.

* Add missing pieces to bwd weight factory.

* Improve compile time erros message when no matching factory is found.

* Use amcro to ensure automatic macthing between concepts are their string representations.

* Improve compile time diagnostics.

* Small improvements.

* Improve missing member/wrong type compile-time errors.

* Improve compile time diagnostics.

* Concept bug fixes.

* Remove debug assert.

* Update algorithm signature diagnostics.

* Factory bug fixes.

* First functional version of bwd weight conv factory.

* Refactor handing of GEMM-K batch template parameter in conv bwd weight factory.

* Concept improvements.

* Improve concept diagnostics.

* Introduve a common size type for concepts.

* Update compiletime diagnostics to use the size type.

* Update conv specialization enum.

* Fix fwd conv builder tests.

* Fix smoke tests.

* Separate bwd weigth and bwd data tests into separate targets.

* Clean-up CK Tile builder tests.

* Add bwd weight XDL CShuffle V3 factory.

* Build conv bwd weigth v3 instances successfully.

* Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Test fix.

* Add instance traits for bwd weight algorithms.

* Add unit tests for instance strings.

* Build new instance traits unit tests but exclude WMMA for now.

* Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Conv bwd weight DL factory.

* Final implementation for bwd weight DL factory.

* Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle

* Treat ref algorithm the same way as real algorithms in the dispatcher.

* Refactor large tensor support and WMMA configuration.

* Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.

* Update Readme.

* Fix WMMA bwd weight tests.

* Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3.

* Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.

* Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3

* Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and  compute types for input and output tensor in bwd weigth convs.

* Fix fwd factories after refactoring.

* clang-format

* Move compile-time diagnostics to a separate branch.

* Fix ref algorithm dispatching.

* Fix smoke tests.

* clang-format

* Fix factory for regular WMMA conv bwd weight.

* Clarify builder Readme.

* Remove obsolete test file.

* Fix test after merge.

* clang-format

* Remove the C++26 extensions.

* Unify conv elementwise ops and layout definitions for fwd and bwd directions.

* Remove old layout and elementwise ops.

* Unify handling of conv tensor types between fwd and bwd directions.

* Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.

* Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms.

* clang-format

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2026-01-13 18:12:38 +02:00
committed by GitHub
parent 710fa1fd3d
commit 9908a87c31
69 changed files with 2956 additions and 832 deletions

View File

@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
} block_gemm;
} block_gemm_pipeline;
} kAlgorithm;
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
constexpr struct Algorithm
{
struct GridwiseGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} gridwise_gemm;
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} kAlgorithm;
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
@@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
{
constexpr struct Algorithm
{
ckb::ConvFwdSpecialization fwd_specialization =
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
ckb::ConvSpecialization fwd_specialization =
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
} kAlgorithm;
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();