diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 9b82909bbb..f99d21db2d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -256,41 +256,4 @@ concept SpecifiesDlEpilogue = requires { { T::transfer.c.epilogue } -> DlEpilogueDescriptor; }; -/******************************************** */ -/* Concepts for the different device ops */ -/******************************************** */ - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && 
SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -template -concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle && - SpecifiesLargeTensorSupport; - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 1d6f9405c5..a76bd0d9d2 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -1,6 +1,52 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +// Compile-time dispatcher for convolution kernel instantiation. +// +// This header provides a centralized factory dispatch mechanism that routes algorithm +// specifications to appropriate convolution kernel implementations at compile-time. +// +// ## Design Overview +// +// The dispatcher operates in two phases: +// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`, +// `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect +// the algorithm descriptor's structure to determine which kernel variant it satisfies. +// Each predicate checks a specific set of concept constraints that define a kernel variant. +// +// 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr` +// to dispatch to the appropriate factory class based on both the convolution direction +// and the identified algorithm type. All routing decisions occur at compile-time, +// ensuring zero runtime overhead. +// +// ## Supported Kernel Variants +// +// - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters +// than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs). 
+// +// - **XDL**: Standard XDL-based kernel using AMD XDLops hardware instructions for matrix +// multiply. Requires full scheduling configuration including prefetch stages and loop scheduler. +// +// - **WMMA**: Wave Matrix Multiply Accumulate variant optimized for WMMA-capable hardware. +// Requires similar configuration to XDL. +// +// - **DL**: Specialized Direct Load kernel optimized for specific data layouts (NHWC/KYXC/NHWK). +// Uses DL-specific configuration for thread mapping and epilogue. +// +// - **Large Tensor**: XDL-based kernel with extended tensor support. Wraps a base XDL algorithm +// and adds large tensor capabilities. +// +// ## Current Limitations +// +// Currently only forward convolution is supported. Backward data and backward weight convolution +// directions will fail at compile-time with informative static_assert messages. +// +// ## Usage Example +// +// ``` +// auto kernel = make_conv_instance(); +// ``` + #pragma once #include "ck_tile/builder/conv_signature_concepts.hpp" @@ -16,42 +62,97 @@ namespace ck_tile::builder::factory { -// Forward declaration of the dispatcher function -template -constexpr auto make_conv_instance(); +// This dispatch logic is rigid and confusing for users. Further, it hides most of +// the great error messages from our concepts. +// +// Requirements for a good design: +// 1. Fall through is bad: inputs should get directly to an implementation +// if we are going to have good compiler errors. +// 2. Logic should be easy for library users to understand. +// 3. Logic should be easy to test, maintain, and extend. +// +// We should probably add explicit tags to the algorithm descriptors, at least +// for the initial implementation. +// +// TODO: Make this dispatch logic much more robust and clear for users. 
+ +// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) +template +consteval bool IsXdlV3Algorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesBlockGemm; +} + +// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) +template +consteval bool IsXdlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; +} + +// WMMA-based kernel (uses Wave Matrix Multiply Accumulate instructions) +template +consteval bool IsWmmaAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; +} + +// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts +template +consteval bool IsDlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; +} + +// XDL-based kernel with large tensor support +template +consteval bool IsLargeTensorAlgorithm() +{ + return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; +} -// Implementation of the dispatcher template constexpr auto make_conv_instance() { - // Check 
convolution direction if constexpr(ConvDirectionIsForward) { - // Forward convolution dispatch - // Check which algorithm concept the ALGORITHM satisfies using AlgoType = std::remove_const_t; - if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) + if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + else if constexpr(IsXdlAlgorithm()) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle) + else if constexpr(IsWmmaAlgorithm()) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK) + else if constexpr(IsDlAlgorithm()) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor) + else if constexpr(IsLargeTensorAlgorithm()) { return typename ConvFwdLargeTensorFactory::Instance{}; } @@ -60,12 +161,8 @@ constexpr auto make_conv_instance() static_assert( false, "No suitable forward convolution kernel factory found for the provided ALGORITHM. 
" - "The ALGORITHM must satisfy one of the following concepts: " - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, " - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, " - "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, " - "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or " - "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor."); + "The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC " + "layout), or Large Tensor variant."); } } else if constexpr(ConvDirectionIsBackwardData) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ba1690932a..dee918cc1f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -21,8 +21,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - std::remove_const_t> + requires ConvDirectionIsForward struct ConvFwdDlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 743148b277..c796ff9177 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -34,9 +34,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - std::remove_const_t> + requires ConvDirectionIsForward struct ConvFwdLargeTensorFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp 
b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 57b51be9af..ad547e87c1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + requires ConvDirectionIsForward struct ConvFwdXdlV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index f699e3d29e..9bb127ea8d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle> + requires ConvDirectionIsForward struct ConvFwdWmmaFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index b9f92c1fee..06e2d5f8ff 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -23,8 +23,7 @@ namespace ck_tile::builder::factory { template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle> + requires ConvDirectionIsForward struct ConvFwdXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;