mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Clean up convolution dispatcher
The concepts for each factory implementation are literally implementation details rather than domain concepts. Move that logic into helper functions to make it easier to study and refactor.
This commit is contained in:
@@ -256,41 +256,4 @@ concept SpecifiesDlEpilogue = requires {
|
||||
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Concepts for the different device ops */
|
||||
/******************************************** */
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(T::base_algorithm)> &&
|
||||
SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -1,6 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Compile-time dispatcher for convolution kernel instantiation.
|
||||
//
|
||||
// This header provides a centralized factory dispatch mechanism that routes algorithm
|
||||
// specifications to appropriate convolution kernel implementations at compile-time.
|
||||
//
|
||||
// ## Design Overview
|
||||
//
|
||||
// The dispatcher operates in two phases:
|
||||
// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`,
|
||||
// `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect
|
||||
// the algorithm descriptor's structure to determine which kernel variant it satisfies.
|
||||
// Each predicate checks a specific set of concept constraints that define a kernel variant.
|
||||
//
|
||||
// 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr`
|
||||
// to dispatch to the appropriate factory class based on both the convolution direction
|
||||
// and the identified algorithm type. All routing decisions occur at compile-time,
|
||||
// ensuring zero runtime overhead.
|
||||
//
|
||||
// ## Supported Kernel Variants
|
||||
//
|
||||
// - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters
|
||||
// than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs).
|
||||
//
|
||||
// - **XDL**: Standard XDL-based kernel using AMD XDLops hardware instructions for matrix
|
||||
// multiply. Requires full scheduling configuration including prefetch stages and loop scheduler.
|
||||
//
|
||||
// - **WMMA**: Wave Matrix Multiply Accumulate variant optimized for WMMA-capable hardware.
|
||||
// Requires similar configuration to XDL.
|
||||
//
|
||||
// - **DL**: Specialized Direct Load kernel optimized for specific data layouts (NHWC/KYXC/NHWK).
|
||||
// Uses DL-specific configuration for thread mapping and epilogue.
|
||||
//
|
||||
// - **Large Tensor**: XDL-based kernel with extended tensor support. Wraps a base XDL algorithm
|
||||
// and adds large tensor capabilities.
|
||||
//
|
||||
// ## Current Limitations
|
||||
//
|
||||
// Currently only forward convolution is supported. Backward data and backward weight convolution
|
||||
// directions will fail at compile-time with informative static_assert messages.
|
||||
//
|
||||
// ## Usage Example
|
||||
//
|
||||
// ```
|
||||
// auto kernel = make_conv_instance<my_signature, my_algorithm_descriptor, "v1">();
|
||||
// ```
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
@@ -16,42 +62,97 @@
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Forward declaration of the dispatcher function
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance();
|
||||
// This dispatch logic is rigid and confusing for users. Further, it hides most of
|
||||
// the great error messages from our concepts.
|
||||
//
|
||||
// Requirements for a good design:
|
||||
// 1. Fall through is bad: inputs should get directly to an implementation
|
||||
// if we are going to have good compiler errors.
|
||||
// 2. Logic should be easy for library users to understand.
|
||||
// 3. Logic should be easy to test, maintain, and extend.
|
||||
//
|
||||
// We should probably add explicit tags to the algorithm descriptors, at least
|
||||
// for the initial implementation.
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline).
//
// Compile-time predicate: returns true only when T satisfies every concept in
// the conjunction below. Compared with IsXdlAlgorithm, it requires
// SpecifiesBlockGemm<T> in place of SpecifiesNumPrefetchStages<T>,
// SpecifiesNumGroupsToMerge<T>, and SpecifiesLoopScheduler<T>.
//
// make_conv_instance() tests this predicate first in its forward-convolution
// chain and, when it holds, selects ConvFwdXdlV3Factory.
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
}
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply).
//
// Compile-time predicate: returns true only when T satisfies every concept in
// the conjunction below. On top of the thread-block, gridwise XDL GEMM,
// transfer, access-order, and specialization concepts it shares with
// IsXdlV3Algorithm, it requires SpecifiesNumPrefetchStages<T>,
// SpecifiesNumGroupsToMerge<T>, and SpecifiesLoopScheduler<T>.
//
// make_conv_instance() selects ConvFwdXdlFactory when this holds and the V3
// predicate did not. IsLargeTensorAlgorithm also applies this predicate to the
// type of a descriptor's `base_algorithm` member.
|
||||
template <typename T>
|
||||
consteval bool IsXdlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
|
||||
}
|
||||
|
||||
// WMMA-based kernel (uses Wave Matrix Multiply Accumulate instructions).
//
// Compile-time predicate: returns true only when T satisfies every concept in
// the conjunction below. The list mirrors IsXdlAlgorithm except that it
// requires SpecifiesGridwiseWmmaGemm<T> instead of SpecifiesGridwiseXdlGemm<T>
// and does not require SpecifiesNumGroupsToMerge<T>.
//
// make_conv_instance() selects ConvFwdWmmaFactory when this holds and neither
// XDL predicate matched earlier in the chain.
|
||||
template <typename T>
|
||||
consteval bool IsWmmaAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
}
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts.
//
// Compile-time predicate: returns true only when T satisfies every concept in
// the conjunction below. Unlike the XDL and WMMA predicates, it does not
// require any gridwise GEMM, block/LDS transfer, or access-order concept;
// instead it requires the DL-specific thread config, thread cluster, block
// transfer, and epilogue concepts.
//
// make_conv_instance() selects ConvFwdDlFactory when this holds and none of
// the XDL V3, XDL, or WMMA predicates matched earlier in the chain.
|
||||
template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
}
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
consteval bool IsLargeTensorAlgorithm()
|
||||
{
|
||||
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
|
||||
}
|
||||
|
||||
// Implementation of the dispatcher
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance()
|
||||
{
|
||||
// Check convolution direction
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
// Forward convolution dispatch
|
||||
// Check which algorithm concept the ALGORITHM satisfies
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<AlgoType>)
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<AlgoType>)
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<AlgoType>)
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<AlgoType>)
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<AlgoType>)
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
@@ -60,12 +161,8 @@ constexpr auto make_conv_instance()
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable forward convolution kernel factory found for the provided ALGORITHM. "
|
||||
"The ALGORITHM must satisfy one of the following concepts: "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, "
|
||||
"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, "
|
||||
"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or "
|
||||
"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.");
|
||||
"The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC "
|
||||
"layout), or Large Tensor variant.");
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
|
||||
@@ -21,8 +21,7 @@ namespace ck_tile::builder::factory {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
std::remove_const_t<decltype(ALGORITHM)>>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
|
||||
@@ -34,9 +34,7 @@ namespace ck_tile::builder::factory {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
std::remove_const_t<decltype(ALGORITHM)>>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
|
||||
@@ -23,8 +23,7 @@ namespace ck_tile::builder::factory {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
|
||||
@@ -23,8 +23,7 @@ namespace ck_tile::builder::factory {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
|
||||
@@ -23,8 +23,7 @@ namespace ck_tile::builder::factory {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
|
||||
Reference in New Issue
Block a user