mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Convert to dispatching through a function
The template specialization and resolution were not working. In particular, SFINAE caused all the helpful errors from asserts to be hidden. We make the dispatch explicit through a function and get rid of the template specialization.
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/conv_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
@@ -15,7 +15,7 @@ namespace ck_tile::builder {
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* It uses a dispatcher function based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
|
||||
@@ -30,9 +30,8 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
// Output: The kernel class.
|
||||
using Instance = Factory::Instance;
|
||||
// Output: The kernel class instance created via the dispatcher.
|
||||
using Instance = decltype(make_conv_instance<SIGNATURE, ALGORITHM, VERSION>());
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
// Include all factory implementations
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Forward declaration of the dispatcher function
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance();
|
||||
|
||||
// Implementation of the dispatcher
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance()
|
||||
{
|
||||
// Check convolution direction
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
// Forward convolution dispatch
|
||||
// Check which algorithm concept the ALGORITHM satisfies
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable forward convolution kernel factory found for the provided ALGORITHM. "
|
||||
"The ALGORITHM must satisfy one of the following concepts: "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, "
|
||||
"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, "
|
||||
"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or "
|
||||
"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.");
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward data convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward weight convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false,
|
||||
"Invalid or unsupported convolution direction. "
|
||||
"The SIGNATURE must specify a valid ConvDirection: FORWARD, BACKWARD_DATA, "
|
||||
"or BACKWARD_WEIGHT.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -1,47 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// A factory for instantiating CK convolution kernels.
|
||||
//
|
||||
// This file translates a semantic description of a convolution operation
|
||||
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
|
||||
// low-level template arguments required by the underlying CK device-level
|
||||
// kernel implementations. This abstraction enables more complex build
|
||||
// time logic and simplifies the kernel specification.
|
||||
//
|
||||
// Key Components:
|
||||
//
|
||||
// Template Metaprogram:
|
||||
// - ConvFactory: The main factory, with specializations for different
|
||||
// convolution directions (currently only forward).
|
||||
//
|
||||
// The primary entry point is the `ConvFactory` struct, which is specialized
|
||||
// for different forward convolution kernel types in separate header files.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Primary template for the convolution factory.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
struct ConvFactory
|
||||
{
|
||||
// This will trigger if a specialization for the given convolution direction is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(false, "Unsupported device operation.");
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
// Include all factory specializations
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
@@ -16,14 +16,14 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel using Direct Load (DL) approach.
|
||||
// Factory for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
std::remove_const_t<decltype(ALGORITHM)>>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
|
||||
@@ -29,15 +29,15 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel with large tensor support (N-splitting).
|
||||
// Factory for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
std::remove_const_t<decltype(ALGORITHM)>>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
|
||||
@@ -18,14 +18,14 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
|
||||
@@ -18,14 +18,14 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
|
||||
@@ -18,14 +18,14 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<std::remove_const_t<decltype(ALGORITHM)>>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <concepts>
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/factory/conv_factory.hpp>
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/types.hpp>
|
||||
@@ -680,15 +679,14 @@ struct ConvTraits<Instance>
|
||||
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
|
||||
/// @details This specialization provides backward compatibility for reflecting
|
||||
/// on kernels defined via the `ConvBuilder` interface. It works by first
|
||||
/// creating the `Instance` via the builder's factory, and then delegating
|
||||
/// creating the `Instance` via the builder, and then delegating
|
||||
/// all trait extraction to the `ConvTraits<Instance>` specialization.
|
||||
template <builder::ConvSignatureDescriptor auto SIGNATURE,
|
||||
builder::ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
builder::StringLiteral VERSION>
|
||||
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
|
||||
{
|
||||
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
using Instance = typename Factory::Instance;
|
||||
using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;
|
||||
|
||||
// Delegate to Instance-based ConvTraits
|
||||
using InstanceConvTraits = ConvTraits<Instance>;
|
||||
|
||||
0
experimental/builder/test/test_ckb_conv_builder.cpp
Normal file
0
experimental/builder/test/test_ckb_conv_builder.cpp
Normal file
Reference in New Issue
Block a user