Convert to dispatching through a function

The template specialization and resolution were not working. In particular, SFINAE hid all the helpful error messages from the asserts. We now make the dispatch explicit through a function and remove the template specializations.
This commit is contained in:
John Shumway
2025-11-22 00:27:27 +00:00
parent 304856c233
commit 67004105d4
10 changed files with 112 additions and 68 deletions

View File

@@ -6,7 +6,7 @@
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/factory/conv_factory.hpp"
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder {
@@ -15,7 +15,7 @@ namespace ck_tile::builder {
* @brief Top-level builder for creating convolution kernel instances.
*
* This struct serves as the main entry point for generating a convolution kernel.
* It uses a factory pattern based on the provided signature, algorithm, and version
* It uses a dispatcher function based on the provided signature, algorithm, and version
* to construct the appropriate kernel instance.
*
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
@@ -30,9 +30,8 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBuilder
{
static constexpr auto kVersion = VERSION;
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
// Output: The kernel class.
using Instance = Factory::Instance;
// Output: The kernel class instance created via the dispatcher.
using Instance = decltype(make_conv_instance<SIGNATURE, ALGORITHM, VERSION>());
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,94 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
// Include all factory implementations
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
namespace ck_tile::builder {
// Forward declaration of the dispatcher function.
//
// make_conv_instance<SIGNATURE, ALGORITHM, VERSION>() takes no runtime
// arguments; everything it needs is passed as template arguments. Its return
// type is deduced (`auto`), so the function cannot be used (for example inside
// decltype) until its definition has been seen.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
constexpr auto make_conv_instance();
// Implementation of the dispatcher
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
constexpr auto make_conv_instance()
{
// Check convolution direction
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
// Forward convolution dispatch
// Check which algorithm concept the ALGORITHM satisfies
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<AlgoType>)
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<AlgoType>)
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<AlgoType>)
{
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<AlgoType>)
{
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else
{
static_assert(
false,
"No suitable forward convolution kernel factory found for the provided ALGORITHM. "
"The ALGORITHM must satisfy one of the following concepts: "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, "
"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, "
"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or "
"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.");
}
}
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
static_assert(
false,
"Backward data convolution is not yet supported. "
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
}
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
static_assert(
false,
"Backward weight convolution is not yet supported. "
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
}
else
{
static_assert(false,
"Invalid or unsupported convolution direction. "
"The SIGNATURE must specify a valid ConvDirection: FORWARD, BACKWARD_DATA, "
"or BACKWARD_WEIGHT.");
}
}
} // namespace ck_tile::builder

View File

@@ -1,47 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// A factory for instantiating CK convolution kernels.
//
// This file translates a semantic description of a convolution operation
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
// low-level template arguments required by the underlying CK device-level
// kernel implementations. This abstraction enables more complex build
// time logic and simplifies the kernel specification.
//
// Key Components:
//
// Template Metaprogram:
// - ConvFactory: The main factory, with specializations for different
// convolution directions (currently only forward).
//
// The primary entry point is the `ConvFactory` struct, which is specialized
// for different forward convolution kernel types in separate header files.
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder {
// Primary template for the convolution factory.
//
// The primary template is never meant to be used directly: instantiating it
// means no specialization matched the given SIGNATURE / ALGORITHM / VERSION,
// and compilation stops with the message below.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          auto VERSION>
struct ConvFactory
{
    // This will trigger if a specialization for the given convolution direction is not found.
    // We should always catch this in an earlier validation check.
    //
    // The condition is always false but depends on a template parameter, so it
    // is only evaluated when this primary template is instantiated. A literal
    // `false` here is rejected at definition time by compilers that do not
    // implement CWG2518 (P2593R1).
    static_assert(sizeof(decltype(VERSION)) == 0, "Unsupported device operation.");
};
} // namespace ck_tile::builder
// Include all factory specializations
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"

View File

@@ -16,14 +16,14 @@
namespace ck_tile::builder {
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
// of a grouped forward convolution kernel using Direct Load (DL) approach.
// Factory for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
std::remove_const_t<decltype(ALGORITHM)>>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
struct ConvFwdDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,

View File

@@ -29,15 +29,15 @@
namespace ck_tile::builder {
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
// of a grouped forward convolution kernel with large tensor support (N-splitting).
// Factory for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
std::remove_const_t<decltype(ALGORITHM)>>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
struct ConvFwdLargeTensorFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,

View File

@@ -18,14 +18,14 @@
namespace ck_tile::builder {
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
struct ConvFwdXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,

View File

@@ -18,14 +18,14 @@
namespace ck_tile::builder {
// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance
// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
struct ConvFwdWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,

View File

@@ -18,14 +18,14 @@
namespace ck_tile::builder {
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
struct ConvFwdXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,

View File

@@ -5,7 +5,6 @@
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/factory/conv_factory.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/types.hpp>
@@ -680,15 +679,14 @@ struct ConvTraits<Instance>
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder's factory, and then delegating
/// creating the `Instance` via the builder, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
using Instance = typename Factory::Instance;
using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;