From 96eb0ef19392da5f42d703155d64a3915247553e Mon Sep 17 00:00:00 2001 From: John Shumway Date: Sat, 22 Nov 2025 17:18:49 +0000 Subject: [PATCH] Clean up convolution dispatcher The concepts for each factory implemenation are literally implementation details rather than domain concepts. Move that logic into helper functions to make it easier to study and refactor. --- .../builder/conv_algorithm_concepts.hpp | 37 ----- .../builder/factory/conv_dispatcher.hpp | 137 +++++++++++++++--- .../builder/factory/conv_fwd_dl_factory.hpp | 3 +- .../factory/conv_fwd_large_tensor_factory.hpp | 4 +- .../builder/factory/conv_fwd_v3_factory.hpp | 3 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 3 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 3 +- 7 files changed, 122 insertions(+), 68 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 9b82909bbb..f99d21db2d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -256,41 +256,4 @@ concept SpecifiesDlEpilogue = requires { { T::transfer.c.epilogue } -> DlEpilogueDescriptor; }; -/******************************************** */ -/* Concepts for the different device ops */ -/******************************************** */ - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -template -concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle && - SpecifiesLargeTensorSupport; - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 1d6f9405c5..a76bd0d9d2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -1,6 +1,52 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +// Compile-time dispatcher for convolution kernel instantiation. +// +// This header provides a centralized factory dispatch mechanism that routes algorithm +// specifications to appropriate convolution kernel implementations at compile-time. +// +// ## Design Overview +// +// The dispatcher operates in two phases: +// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`, +// `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect +// the algorithm descriptor's structure to determine which kernel variant it satisfies. +// Each predicate checks a specific set of concept constraints that define a kernel variant. +// +// 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr` +// to dispatch to the appropriate factory class based on both the convolution direction +// and the identified algorithm type. All routing decisions occur at compile-time, +// ensuring zero runtime overhead. +// +// ## Supported Kernel Variants +// +// - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters +// than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs). +// +// - **XDL**: Standard XDL-based kernel using AMD XDLops hardware instructions for matrix +// multiply. Requires full scheduling configuration including prefetch stages and loop scheduler. +// +// - **WMMA**: Wavefront Matrix-Matrix Accumulate variant optimized for WMMA-capable hardware. +// Requires similar configuration to XDL. +// +// - **DL**: Specialized Direct Load kernel optimized for specific data layouts (NHWC/KYXC/NHWK). +// Uses DL-specific configuration for thread mapping and epilogue. +// +// - **Large Tensor**: XDL-based kernel with extended tensor support. Wraps a base XDL algorithm +// and adds large tensor capabilities. +// +// ## Current Limitations +// +// Currently only forward convolution is supported. Backward data and backward weight convolution +// directions will fail at compile-time with informative static_assert messages. +// +// ## Usage Example +// +// ``` +// auto kernel = make_conv_instance(); +// ``` + #pragma once #include "ck_tile/builder/conv_signature_concepts.hpp" @@ -16,42 +62,97 @@ namespace ck_tile::builder::factory { -// Forward declaration of the dispatcher function -template -constexpr auto make_conv_instance(); +// This dispatch logic is rigid and confusing for users. Further, hides most of +// the great error messages from our concepts. +// +// Requirements for a good design: +// 1. Fall through is bad: inputs should get directly to an implementation +// if we are going to have good compiler errors. +// 2. Logic should be easy for library users to understand. +// 3. Logic should be easy to test, maintain, and extend. +// +// We should probably add explicit tags to the algorithm descriptors, at least +// for the initial implemenation. +// +// TODO: Make this dispatch logic much more robust and clear for users. + +// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) +template +consteval bool IsXdlV3Algorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesBlockGemm; +} + +// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) +template +consteval bool IsXdlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; +} + +// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) +template +consteval bool IsWmmaAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; +} + +// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts +template +consteval bool IsDlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; +} + +// XDL-based kernel with large tensor support +template +consteval bool IsLargeTensorAlgorithm() +{ + return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; +} -// Implementation of the dispatcher template constexpr auto make_conv_instance() { - // Check convolution direction if constexpr(ConvDirectionIsForward) { - // Forward convolution dispatch - // Check which algorithm concept the ALGORITHM satisfies using AlgoType = std::remove_const_t; - if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) + if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + else if constexpr(IsXdlAlgorithm()) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle) + else if constexpr(IsWmmaAlgorithm()) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK) + else if constexpr(IsDlAlgorithm()) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor) + else if constexpr(IsLargeTensorAlgorithm()) { return typename ConvFwdLargeTensorFactory::Instance{}; } @@ -60,12 +161,8 @@ constexpr auto make_conv_instance() static_assert( false, "No suitable forward convolution kernel factory found for the provided ALGORITHM. " - "The ALGORITHM must satisfy one of the following concepts: " - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, " - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, " - "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, " - "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or " - "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor."); + "The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC " + "layout), or Large Tensor variant."); } } else if constexpr(ConvDirectionIsBackwardData) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ba1690932a..dee918cc1f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -21,8 +21,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - std::remove_const_t> + requires ConvDirectionIsForward struct ConvFwdDlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 743148b277..c796ff9177 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -34,9 +34,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - std::remove_const_t> + requires ConvDirectionIsForward struct ConvFwdLargeTensorFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 57b51be9af..ad547e87c1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + requires ConvDirectionIsForward struct ConvFwdXdlV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index f699e3d29e..9bb127ea8d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle> + requires ConvDirectionIsForward struct ConvFwdWmmaFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index b9f92c1fee..06e2d5f8ff 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle> + requires ConvDirectionIsForward struct ConvFwdXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;