From 67004105d4f9370aeef7cedc1c535d4e28436909 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Sat, 22 Nov 2025 00:27:27 +0000 Subject: [PATCH] Convert to dispatching through a function The template specialization and resolution were not working. In particular, SFINAE caused all the helpful errors from asserts to be hidden. We make the dispatch explicit through a function and get rid of the template specialization. --- .../include/ck_tile/builder/conv_builder.hpp | 9 +- .../builder/factory/conv_dispatcher.hpp | 94 +++++++++++++++++++ .../ck_tile/builder/factory/conv_factory.hpp | 47 ---------- .../builder/factory/conv_fwd_dl_factory.hpp | 6 +- .../factory/conv_fwd_large_tensor_factory.hpp | 6 +- .../builder/factory/conv_fwd_v3_factory.hpp | 4 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 4 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 4 +- .../ck_tile/builder/reflect/conv_traits.hpp | 6 +- .../builder/test/test_ckb_conv_builder.cpp | 0 10 files changed, 112 insertions(+), 68 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp delete mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_factory.hpp create mode 100644 experimental/builder/test/test_ckb_conv_builder.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_builder.hpp b/experimental/builder/include/ck_tile/builder/conv_builder.hpp index 78f9d9b7c4..2a105406e0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_builder.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_builder.hpp @@ -6,7 +6,7 @@ #include #include -#include "ck_tile/builder/factory/conv_factory.hpp" +#include "ck_tile/builder/factory/conv_dispatcher.hpp" #include "ck_tile/builder/versions.hpp" namespace ck_tile::builder { @@ -15,7 +15,7 @@ namespace ck_tile::builder { * @brief Top-level builder for creating convolution kernel instances. * * This struct serves as the main entry point for generating a convolution kernel. 
- * It uses a factory pattern based on the provided signature, algorithm, and version + * It uses a dispatcher function based on the provided signature, algorithm, and version * to construct the appropriate kernel instance. * * @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of @@ -30,9 +30,8 @@ template ; - // Output: The kernel class. - using Instance = Factory::Instance; + // Output: The kernel class instance created via the dispatcher. + using Instance = decltype(make_conv_instance()); }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp new file mode 100644 index 0000000000..7adf0b5305 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -0,0 +1,94 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +// Include all factory implementations +#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" + +namespace ck_tile::builder { + +// Forward declaration of the dispatcher function +template +constexpr auto make_conv_instance(); + +// Implementation of the dispatcher +template +constexpr auto make_conv_instance() +{ + // Check convolution direction + if constexpr(ConvDirectionIsForward) + { + // Forward convolution dispatch + // Check which algorithm concept the ALGORITHM satisfies + using AlgoType = std::remove_const_t; + + if 
constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) + { + return typename ConvFwdXdlV3Factory::Instance{}; + } + else if constexpr(DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + { + return typename ConvFwdXdlFactory::Instance{}; + } + else if constexpr(DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle) + { + return typename ConvFwdWmmaFactory::Instance{}; + } + else if constexpr(DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK) + { + return typename ConvFwdDlFactory::Instance{}; + } + else if constexpr(DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor) + { + return typename ConvFwdLargeTensorFactory::Instance{}; + } + else + { + static_assert( + false, + "No suitable forward convolution kernel factory found for the provided ALGORITHM. " + "The ALGORITHM must satisfy one of the following concepts: " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, " + "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle, " + "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, or " + "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor."); + } + } + else if constexpr(ConvDirectionIsBackwardData) + { + static_assert( + false, + "Backward data convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } + else if constexpr(ConvDirectionIsBackwardWeight) + { + static_assert( + false, + "Backward weight convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } + else + { + static_assert(false, + "Invalid or unsupported convolution direction. 
" + "The SIGNATURE must specify a valid ConvDirection: FORWARD, BACKWARD_DATA, " + "or BACKWARD_WEIGHT."); + } +} + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_factory.hpp deleted file mode 100644 index 53554753b3..0000000000 --- a/experimental/builder/include/ck_tile/builder/factory/conv_factory.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -// A factory for instantiating CK convolution kernels. -// -// This file translates a semantic description of a convolution operation -// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific, -// low-level template arguments required by the underlying CK device-level -// kernel implementations. This abstraction enables more complex build -// time logic and simplifies the kernel specification. -// -// Key Components: -// -// Template Metaprogram: -// - ConvFactory: The main factory, with specializations for different -// convolution directions (currently only forward). -// -// The primary entry point is the `ConvFactory` struct, which is specialized -// for different forward convolution kernel types in separate header files. - -#pragma once - -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/conv_algorithm_concepts.hpp" -#include "ck_tile/builder/versions.hpp" - -namespace ck_tile::builder { - -// Primary template for the convolution factory. -template -struct ConvFactory -{ - // This will trigger if a specialization for the given convolution direction is not found. - // We should always catch this in an earlier validation check. 
- static_assert(false, "Unsupported device operation."); -}; - -} // namespace ck_tile::builder - -// Include all factory specializations -#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" -#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" -#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" -#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" -#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 98594a039b..61b59863b6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -16,14 +16,14 @@ namespace ck_tile::builder { -// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance -// of a grouped forward convolution kernel using Direct Load (DL) approach. +// Factory for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel. 
template requires ConvDirectionIsForward && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< std::remove_const_t> -struct ConvFactory +struct ConvFwdDlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = decltype(factory_internal::GetTensorLayout requires ConvDirectionIsForward && DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< std::remove_const_t> -struct ConvFactory +struct ConvFwdLargeTensorFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = decltype(factory_internal::GetTensorLayout requires ConvDirectionIsForward && DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -struct ConvFactory +struct ConvFwdXdlV3Factory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = decltype(factory_internal::GetTensorLayout requires ConvDirectionIsForward && DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle> -struct ConvFactory +struct ConvFwdWmmaFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = decltype(factory_internal::GetTensorLayout requires ConvDirectionIsForward && DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle> -struct ConvFactory +struct ConvFwdXdlFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; using Layouts = decltype(factory_internal::GetTensorLayout #include -#include #include #include #include @@ -680,15 +679,14 @@ struct ConvTraits /// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. /// @details This specialization provides backward compatibility for reflecting /// on kernels defined via the `ConvBuilder` interface. It works by first -/// creating the `Instance` via the builder's factory, and then delegating +/// creating the `Instance` via the builder, and then delegating /// all trait extraction to the `ConvTraits` specialization. 
template struct ConvTraits> { - using Factory = builder::ConvFactory; - using Instance = typename Factory::Instance; + using Instance = typename builder::ConvBuilder::Instance; // Delegate to Instance-based ConvTraits using InstanceConvTraits = ConvTraits; diff --git a/experimental/builder/test/test_ckb_conv_builder.cpp b/experimental/builder/test/test_ckb_conv_builder.cpp new file mode 100644 index 0000000000..e69de29bb2