mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit '280bc4219151c3f79fe8ca076a2d10df4ff88b34' into develop
This commit is contained in:
@@ -84,7 +84,7 @@ concept LdsTransferDescriptor = requires(T t) {
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
@@ -256,41 +256,4 @@ concept SpecifiesDlEpilogue = requires {
|
||||
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Concepts for the different device ops */
|
||||
/******************************************** */
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(T::base_algorithm)> &&
|
||||
SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -18,8 +18,8 @@ concept InputVectorTransferLimits = requires {
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
requires Value.scalar_per_vector > 0 && Value.m_per_wave_per_shuffle > 0 &&
|
||||
Value.n_per_wave_per_shuffle > 0;
|
||||
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
|
||||
Value.n_xdl_per_wave_per_shuffle > 0;
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/conv_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
@@ -15,7 +15,7 @@ namespace ck_tile::builder {
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* It uses a dispatcher function based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
|
||||
@@ -30,9 +30,8 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
// Output: The kernel class.
|
||||
using Instance = Factory::Instance;
|
||||
// Output: The kernel class instance created via the dispatcher.
|
||||
using Instance = decltype(factory::make_conv_instance<SIGNATURE, ALGORITHM, VERSION>());
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,31 @@
|
||||
# Convolution Builder Factory Directory
|
||||
|
||||
This directory implements compile-time dispatch from high-level signature algorithm descriptors to our exisitng specialized convolution kernel implementations.
|
||||
|
||||
See the [main builder documentation](../README.md) for an overview.
|
||||
|
||||
## Design Overview
|
||||
|
||||
The factory system operates in two phases:
|
||||
|
||||
1. **Algorithm Classification**: The function `make_conv_instance` in `conv_dispatcher.hpp` inspects the signature and algorithm descriptors to determine which kernel variant they satisfy (XDL V3, XDL, WMMA, DL, or Large Tensor)
|
||||
|
||||
2. **Factory Instantiation**: Each factory (`conv_fwd_*_factory.hpp`) transforms builder descriptors into CK device operation template parameters and instantiates the corresponding kernel device operation.
|
||||
|
||||
## Key Files
|
||||
|
||||
- **`conv_dispatcher.hpp`**: Entry point with `make_conv_instance()` function. Contains dispatch logic and algorithm classification predicates. **Start here** to understand the overall flow.
|
||||
|
||||
- **`conv_fwd_*_factory.hpp`**: Individual factories for each kernel variant. Each extracts configuration from descriptors, validates parameters, and instantiates the underlying CK device operation.
|
||||
|
||||
- **`helpers/`**: Transformation utilities that map builder types to CK device operation parameters (layouts, data types, elementwise ops, block configurations, etc.)
|
||||
|
||||
## Usage
|
||||
|
||||
```cpp
|
||||
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
|
||||
|
||||
using Factory = decltype(make_conv_instance<signature, algorithm, "v1">());
|
||||
```
|
||||
|
||||
The dispatcher automatically selects the appropriate factory following explicit logic.
|
||||
@@ -0,0 +1,196 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Compile-time dispatcher for convolution kernel instantiation.
|
||||
//
|
||||
// This header provides a centralized factory dispatch mechanism that routes algorithm
|
||||
// specifications to appropriate convolution kernel implementations at compile-time.
|
||||
//
|
||||
// ## Design Overview
|
||||
//
|
||||
// The dispatcher operates in two phases:
|
||||
// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`,
|
||||
// `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect
|
||||
// the algorithm descriptor's structure to determine which kernel variant it satisfies.
|
||||
// Each predicate checks a specific set of concept constraints that define a kernel variant.
|
||||
//
|
||||
// 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr`
|
||||
// to dispatch to the appropriate factory class based on both the convolution direction
|
||||
// and the identified algorithm type. All routing decisions occur at compile-time,
|
||||
// ensuring zero runtime overhead.
|
||||
//
|
||||
// ## Supported Kernel Variants
|
||||
//
|
||||
// - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters
|
||||
// than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs).
|
||||
//
|
||||
// - **XDL**: Standard XDL-based kernel using AMD XDLops hardware instructions for matrix
|
||||
// multiply. Requires full scheduling configuration including prefetch stages and loop scheduler.
|
||||
//
|
||||
// - **WMMA**: Wavefront Matrix-Matrix Accumulate variant optimized for WMMA-capable hardware.
|
||||
// Requires similar configuration to XDL.
|
||||
//
|
||||
// - **DL**: Specialized vectorized dot-product kernel optimized for specific data layouts
|
||||
// (NHWC/KYXC/NHWK). The "DL" label just indicates this does not use XDLops instructions.
|
||||
//
|
||||
// - **Large Tensor**: XDL-based kernel with extended tensor support. Wraps a base XDL algorithm
|
||||
// and adds large tensor capabilities.
|
||||
//
|
||||
// ## Current Limitations
|
||||
//
|
||||
// Currently only forward convolution is supported. Backward data and backward weight convolution
|
||||
// directions will fail at compile-time with informative static_assert messages.
|
||||
//
|
||||
// ## Usage Example
|
||||
//
|
||||
// ```
|
||||
// auto kernel = make_conv_instance<SIGNATURE, ALGORITHM>();
|
||||
// ```
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
// Include all factory implementations
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// This dispatch logic is rigid and confusing for users. Further, hides most of
|
||||
// the great error messages from our concepts.
|
||||
//
|
||||
// Requirements for a good design:
|
||||
// 1. Fall through is bad: inputs should get directly to an implementation
|
||||
// if we are going to have good compiler errors.
|
||||
// 2. Logic should be easy for library users to understand.
|
||||
// 3. Logic should be easy to test, maintain, and extend.
|
||||
//
|
||||
// We should probably add explicit tags to the algorithm descriptors, at least
|
||||
// for the initial implemenation.
|
||||
//
|
||||
// To avoid changing behavior too much during refactoring, we leave the explicit
|
||||
// dispatch logic here for now, just changing it from SFINAE to consteval + if constexpr.
|
||||
// There may be some subtle behavior changes, but build failure messages will be more
|
||||
// clear.
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
}
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
consteval bool IsXdlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
consteval bool IsWmmaAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
}
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
consteval bool IsLargeTensorAlgorithm()
|
||||
{
|
||||
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
|
||||
}
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance()
|
||||
{
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable forward convolution kernel factory found for the provided ALGORITHM. "
|
||||
"The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC "
|
||||
"layout), or Large Tensor variant.");
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward data convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward weight convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false,
|
||||
"Invalid or unsupported convolution direction. "
|
||||
"The SIGNATURE must specify a valid ConvDirection: FORWARD, BACKWARD_DATA, "
|
||||
"or BACKWARD_WEIGHT.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,117 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<BASE_ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,119 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
|
||||
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load;
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
IS_DIRECT_LOAD>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
LOOP_SCHEDULER,
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
LOOP_SCHEDULER,
|
||||
ALGORITHM.num_groups_to_merge>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/array.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle = 0;
|
||||
size_t n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
size_t scalar_per_vector = 0;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims;
|
||||
auto& epilogue_config = ALGORITHM.transfer.c.epilogue;
|
||||
return CBlockTransfer{
|
||||
.m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
{
|
||||
thread_cluster_dims.m_block,
|
||||
thread_cluster_dims.m_wave_per_xdl,
|
||||
thread_cluster_dims.n_block,
|
||||
thread_cluster_dims.n_wave_per_xdl,
|
||||
},
|
||||
.scalar_per_vector = epilogue_config.scalar_per_vector,
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported elementwise operation for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
// This will trigger if a specialization for the given layout is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
using Layout = decltype(LayoutValue);
|
||||
static_assert(sizeof(Layout) == 0,
|
||||
"Internal error. Unsupported layout for convolution factory.");
|
||||
};
|
||||
|
||||
// 1D Forward Convolution Layout Specializations
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCDHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCZYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKDHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNDHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
consteval auto GetTensorLayout()
|
||||
{
|
||||
|
||||
if constexpr(SPATIAL_DIM == 1)
|
||||
{
|
||||
return internal::ConvTensorLayouts<Layout._1d, 1, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 2)
|
||||
{
|
||||
return internal::ConvTensorLayouts<Layout._2d, 2, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 3)
|
||||
{
|
||||
return internal::ConvTensorLayouts<Layout._3d, 3, DIR>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported spatial dimension for convolution layout.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck::half_t;
|
||||
using AComputeType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using BComputeType = ck::half_t;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck::bhalf_t;
|
||||
using AComputeType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using BComputeType = ck::bhalf_t;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using AComputeType = float;
|
||||
using BDataType = float;
|
||||
using BComputeType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::I8>
|
||||
{
|
||||
using ADataType = int8_t;
|
||||
using AComputeType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using BComputeType = int8_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = int32_t;
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck::f8_t;
|
||||
using AComputeType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::f8_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
/// @brief Data tile dimensions processed by a workgroup.
|
||||
/// @details This struct defines the M, N, and K dimensions of the data tile
|
||||
/// that a single workgroup (thread block) is responsible for processing in the
|
||||
/// underlying GEMM computation.
|
||||
struct DataTileInfo
|
||||
{
|
||||
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
|
||||
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
|
||||
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
|
||||
};
|
||||
|
||||
struct ConvBlock
|
||||
{
|
||||
size_t block_size = 0;
|
||||
DataTileInfo per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{
|
||||
.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
|
||||
struct ConvSpec
|
||||
{
|
||||
CONV_ENUM conv_spec;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec;
|
||||
};
|
||||
|
||||
// Deduction guide for ConvSpec to simplify brace initialization.
|
||||
template <typename CONV_ENUM, typename GEMM_ENUM>
|
||||
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
|
||||
|
||||
struct BlockGemmSpec
|
||||
{
|
||||
ck::BlockGemmPipelineVersion pipeline_version;
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
switch(BG.scheduler)
|
||||
{
|
||||
case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break;
|
||||
case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break;
|
||||
case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
|
||||
switch(BG.pipeline_version)
|
||||
{
|
||||
case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break;
|
||||
case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break;
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
default: throw "Unknown PipelineVersion";
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
case PipelineScheduler::DEFAULT: return ck_loop_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
|
||||
{
|
||||
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
switch(gemm_spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding;
|
||||
default: throw "Unknown GemmSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
using ck_pipeline = ck::BlockGemmPipelineVersion;
|
||||
switch(version)
|
||||
{
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <concepts>
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/conv_factory.hpp>
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
@@ -681,15 +680,14 @@ struct ConvTraits<Instance>
|
||||
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
|
||||
/// @details This specialization provides backward compatibility for reflecting
|
||||
/// on kernels defined via the `ConvBuilder` interface. It works by first
|
||||
/// creating the `Instance` via the builder's factory, and then delegating
|
||||
/// creating the `Instance` via the builder, and then delegating
|
||||
/// all trait extraction to the `ConvTraits<Instance>` specialization.
|
||||
template <builder::ConvSignatureDescriptor auto SIGNATURE,
|
||||
builder::ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
builder::StringLiteral VERSION>
|
||||
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
|
||||
{
|
||||
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
using Instance = typename Factory::Instance;
|
||||
using Instance = typename builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>::Instance;
|
||||
|
||||
// Delegate to Instance-based ConvTraits
|
||||
using InstanceConvTraits = ConvTraits<Instance>;
|
||||
|
||||
Reference in New Issue
Block a user