[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)

* [BUILDER] Ck Tile Grouped convolution factory

* Part 2

* Fixes after rebase

* Remove leftovers
This commit is contained in:
Bartłomiej Kocot
2025-12-08 10:32:56 +01:00
committed by GitHub
parent 8fec8054b2
commit 04612c30ce
55 changed files with 1431 additions and 92 deletions

View File

@@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileTransferDescriptor = requires(T t) {
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept TileBlockGemmDescriptor = requires(T t) {
{ t.warps.m } -> std::convertible_to<int>;
{ t.warps.n } -> std::convertible_to<int>;
{ t.warps.k } -> std::convertible_to<int>;
{ t.warp_tile.m } -> std::convertible_to<int>;
{ t.warp_tile.n } -> std::convertible_to<int>;
{ t.warp_tile.k } -> std::convertible_to<int>;
{ t.double_smem_buffer } -> std::convertible_to<bool>;
{ t.num_wave_groups } -> std::convertible_to<int>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies optimizations (CK Tile).
template <typename T>
concept TileOptimizationsDescriptor = requires(T t) {
{ t.num_groups_to_merge } -> std::convertible_to<int>;
{ t.split_image } -> std::convertible_to<bool>;
{ t.explicit_gemm } -> std::convertible_to<bool>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
// concept.
template <typename T>
@@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if struct specifies thread block info (CK Tile).
template <typename T>
concept SpecifiesTileThreadBlock = requires {
{ T::thread_block } -> TileThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
@@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
template <typename T>
concept SpecifiesTileTransfer = requires(T t) {
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
@@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires {
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
concept SpecifiesTileBlockGemm = requires {
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept SpecifiesTileOptimizations = requires {
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
{ T::optimizations.split_image } -> std::convertible_to<bool>;
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
};
template <typename T>
concept SpecifiesTileConvSpecialization = requires {
{ T::specialization } -> std::convertible_to<TileConvSpecialization>;
};
template <typename T>
concept SpecifiesFwdConvSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};

View File

@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for input and output vector transfer (CK Tile).
template <auto Value>
concept TileInputOutputVectorTransferLimits =
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {

View File

@@ -59,6 +59,7 @@
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
namespace ck_tile::builder::factory {
@@ -81,6 +82,15 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// CK Tile kernel
template <typename T>
consteval bool IsTileAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
SpecifiesTileOptimizations<T>;
}
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
@@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesBlockGemm<T>;
}
@@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
}
@@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
}
@@ -120,7 +130,7 @@ template <typename T>
consteval bool IsDlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
}
@@ -137,10 +147,15 @@ template <ConvSignatureDescriptor auto SIGNATURE,
StringLiteral VERSION>
constexpr auto make_conv_instance()
{
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// CK Tile supports common factory for each direction
if constexpr(IsTileAlgorithm<AlgoType>())
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>())
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};

View File

@@ -7,11 +7,11 @@
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp"
namespace ck_tile::builder::factory {
// Factory for CK Tile Grouped Convolution kernels.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
struct ConvTileFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
using Ops = internal::TileElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations<ALGORITHM>();
static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();
static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection<SIGNATURE>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
CONV_SPECIALIZATION,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
SCALAR_PER_VECTOR.a,
SCALAR_PER_VECTOR.b,
SCALAR_PER_VECTOR.c,
OPTIMIZATIONS.num_groups_to_merge,
OPTIMIZATIONS.split_image,
OPTIMIZATIONS.explicit_gemm>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
BLOCK_GEMM.double_smem_buffer,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::AsLayout,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::BsLayout,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::CLayout,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
BLOCK_GEMM.num_wave_groups>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
GemmShape,
GemmUniversalTraits,
BLOCK_GEMM.scheduler,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Types::EDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename internal::TilePipelineType<
BLOCK_GEMM.pipeline_version>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::AccDataType,
typename Types::EDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
typename Ops::CDEElementwiseOp,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK_GEMM.warps.m,
BLOCK_GEMM.warps.n,
BLOCK_GEMM.warp_tile.m,
BLOCK_GEMM.warp_tile.n,
BLOCK_GEMM.warp_tile.k,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
// TODO:: This template parameter will be moved inside the kernel
ck_tile::memory_operation_enum::set,
BLOCK_GEMM.num_wave_groups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
SCALAR_PER_VECTOR.c>>;
using Instance = typename internal::GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>::Instance;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
struct TileScalarPerVector
{
size_t a = 0;
size_t b = 0;
size_t c = 0;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileScalarPerVector SetTileBlockTransfer()
{
return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector,
.b = ALGORITHM.transfer.b_scalar_per_vector,
.c = ALGORITHM.transfer.c_scalar_per_vector};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
template <ElementwiseOperation Op>
struct ElementwiseOpToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
"Unsupported elementwise operation conversion to CK.");
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>
{
using Op = ck_tile::element_wise::PassThrough;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::SCALE>
{
using Op = ck_tile::element_wise::Scale;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::CLAMP>
{
using Op = ck_tile::element_wise::Clamp;
};
template <auto TensorDesc>
consteval auto GetTileElementwiseOp()
{
if constexpr(HasTensorOp<decltype(TensorDesc)>)
{
constexpr auto op = TensorDesc.operation.elementwise_operation;
return ElementwiseOpToCKTile<op>{};
}
else
{
return ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>{};
}
}
template <auto Sig>
struct TileElementwiseOps
{
static constexpr auto input_op = GetTileElementwiseOp<Sig.input>();
static constexpr auto weight_op = GetTileElementwiseOp<Sig.weight>();
static constexpr auto output_op = GetTileElementwiseOp<Sig.output>();
using AElementwiseOp = typename decltype(input_op)::Op;
using BElementwiseOp = typename decltype(weight_op)::Op;
using CDEElementwiseOp = typename decltype(output_op)::Op;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
struct GroupedConvolutionTileKernel
{
static_assert(false, "Unknown Direction");
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsForward<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsBackwardData<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE>
consteval ck_tile::GroupedConvDirection SetTileConvDirection()
{
constexpr auto direction = SIGNATURE.direction;
using ck_tile_direction = ck_tile::GroupedConvDirection;
switch(direction)
{
case ConvDirection::FORWARD: return ck_tile_direction::FORWARD;
case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA;
case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT;
default: throw "Unknown Direction";
}
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,200 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
template <TensorLayout Layout>
struct LayoutToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
"Unsupported layout conversion to CK.");
};
// Bias layouts
template <>
struct LayoutToCKTile<TensorLayout::G_K_strided>
{
using type = ck_tile::tensor_layout::convolution::G_K;
};
template <>
struct LayoutToCKTile<TensorLayout::GC>
{
using type = ck_tile::tensor_layout::convolution::GC;
};
template <>
struct LayoutToCKTile<TensorLayout::G_C_strided>
{
using type = ck_tile::tensor_layout::convolution::G_C;
};
// Input 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGC>
{
using type = ck_tile::tensor_layout::convolution::NWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWC>
{
using type = ck_tile::tensor_layout::convolution::GNWC;
};
// Input 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGC>
{
using type = ck_tile::tensor_layout::convolution::NHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWC>
{
using type = ck_tile::tensor_layout::convolution::GNHWC;
};
// Input 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGC>
{
using type = ck_tile::tensor_layout::convolution::NDHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWC>
{
using type = ck_tile::tensor_layout::convolution::GNDHWC;
};
// Weight 1D
template <>
struct LayoutToCKTile<TensorLayout::GKXC>
{
using type = ck_tile::tensor_layout::convolution::GKXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCX>
{
using type = ck_tile::tensor_layout::convolution::GKCX;
};
// Weight 2D
template <>
struct LayoutToCKTile<TensorLayout::GKYXC>
{
using type = ck_tile::tensor_layout::convolution::GKYXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCYX>
{
using type = ck_tile::tensor_layout::convolution::GKCYX;
};
// Weight 3D
template <>
struct LayoutToCKTile<TensorLayout::GKCZYX>
{
using type = ck_tile::tensor_layout::convolution::GKCZYX;
};
template <>
struct LayoutToCKTile<TensorLayout::GKZYXC>
{
using type = ck_tile::tensor_layout::convolution::GKZYXC;
};
// Output 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGK>
{
using type = ck_tile::tensor_layout::convolution::NWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWK>
{
using type = ck_tile::tensor_layout::convolution::GNWK;
};
// Output 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGK>
{
using type = ck_tile::tensor_layout::convolution::NHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWK>
{
using type = ck_tile::tensor_layout::convolution::GNHWK;
};
// Output 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGK>
{
using type = ck_tile::tensor_layout::convolution::NDHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWK>
{
using type = ck_tile::tensor_layout::convolution::GNDHWK;
};
template <TensorLayout Layout>
consteval auto TensorLayoutToCKTile()
{
return typename LayoutToCKTile<Layout>::type{};
}
struct EmptyAuxiliaryTileTensorLayout
{
using type = ck_tile::tuple<>;
};
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck_tile::tuple<
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
}
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTileTensorLayouts
{
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
std::make_index_sequence<Size>{}));
};
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return EmptyAuxiliaryTileTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct TileConvTensorLayouts
{
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from builder convolution data type to CK Tile tensor types.
template <DataType T>
struct TileConvTensorTypes
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported data type for convolution factory.");
};
template <>
struct TileConvTensorTypes<DataType::FP16>
{
using ADataType = ck_tile::half_t;
using AComputeType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using BComputeType = ck_tile::half_t;
using CShuffleDataType = ck_tile::half_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::half_t;
};
template <>
struct TileConvTensorTypes<DataType::BF16>
{
using ADataType = ck_tile::bf16_t;
using AComputeType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using BComputeType = ck_tile::bf16_t;
using CShuffleDataType = ck_tile::bf16_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::bf16_t;
};
template <>
struct TileConvTensorTypes<DataType::FP32>
{
using ADataType = float;
using AComputeType = float;
using BDataType = float;
using BComputeType = float;
using CShuffleDataType = float;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = float;
};
template <>
struct TileConvTensorTypes<DataType::I8>
{
using ADataType = int8_t;
using AComputeType = int8_t;
using BDataType = int8_t;
using BComputeType = int8_t;
using CShuffleDataType = int8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = int32_t;
using EDataType = int8_t;
};
template <>
struct TileConvTensorTypes<DataType::FP8>
{
using ADataType = ck_tile::fp8_t;
using AComputeType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using BComputeType = ck_tile::fp8_t;
using CShuffleDataType = ck_tile::fp8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::fp8_t;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockMNK
{
int m{};
int n{};
int k{};
};
struct TileConvBlock
{
TileBlockMNK per_block = {};
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileConvBlock SetTileThreadBlockInfo()
{
constexpr auto& TB = ALGORITHM.thread_block;
return TileConvBlock{
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,158 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockGemmMNK
{
int m{};
int n{};
int k{};
};
struct TileBlockGemmSpec
{
TileBlockGemmMNK warps = {};
TileBlockGemmMNK warp_tile = {};
bool double_smem_buffer = false;
int num_wave_groups = 1;
ck_tile::GemmPipeline pipeline_version;
ck_tile::GemmPipelineScheduler scheduler;
};
struct TileOptimizations
{
int num_groups_to_merge = 1;
bool split_image = false;
bool explicit_gemm = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipelineScheduler SetTileScheduler()
{
constexpr auto scheduler = ALGORITHM.block_gemm.scheduler;
using ck_tile_sched = ck_tile::GemmPipelineScheduler;
switch(scheduler)
{
case PipelineScheduler::DEFAULT: return ck_tile_sched::Default;
case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave;
case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave;
default: throw "Unknown PipelineScheduler";
}
}
template <ck_tile::GemmPipeline PipelineId>
struct TilePipelineType
{
static_assert(false, "Unknown PipelineScheduler");
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_V1>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.block_gemm.pipeline_version;
using ck_tile_pipeline = ck_tile::GemmPipeline;
switch(version)
{
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization()
{
constexpr auto specialization = ALGORITHM.specialization;
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
switch(specialization)
{
case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0:
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
default: throw "Unknown ConvFwdSpecialization";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileBlockGemmSpec SetTileBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm;
constexpr bool double_smem_buffer = BG.double_smem_buffer;
constexpr int num_wave_groups = BG.num_wave_groups;
constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion<ALGORITHM>();
constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler<ALGORITHM>();
return TileBlockGemmSpec{
.warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k},
.warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k},
.double_smem_buffer = double_smem_buffer,
.num_wave_groups = num_wave_groups,
.pipeline_version = pipeline_version,
.scheduler = scheduler};
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileOptimizations SetTileOptimizations()
{
constexpr auto& OPT = ALGORITHM.optimizations;
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
.split_image = OPT.split_image,
.explicit_gemm = OPT.explicit_gemm};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -145,6 +145,15 @@ enum struct GemmSpecialization
MNKOPadding
};
// Enums for the CK Tile convolution specialization.
enum class TileConvSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
{