[BUILDER] Ck Tile Grouped convolution factory

This commit is contained in:
Bartlomiej Kocot
2025-12-04 09:38:17 +00:00
parent 08bd4decf3
commit f032bc3f9a
43 changed files with 885 additions and 48 deletions

View File

@@ -95,6 +95,39 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
// Satisfied by any type exposing a `tile_size` member whose `m`, `n` and `k` fields are each
// convertible to size_t.
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for the vector transfer descriptor of a GEMM problem for CK Tile.
// Satisfied by any type exposing `a_scalar_per_vector`, `b_scalar_per_vector` and
// `c_scalar_per_vector` members, each convertible to size_t.
template <typename T>
concept TileTransferDescriptor = requires(T t) {
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
// Requires `warp` and `warp_tile` members with int-convertible m/n/k fields, plus
// `double_smem_buffer`, `num_wave_groups`, `pipeline_version` and `scheduler` members.
// NOTE(review): this concept reads `t.warp.*`, while SpecifiesTileBlockGemm reads
// `T::block_gemm.warps.*` -- the member spellings differ; confirm which one is intended.
template <typename T>
concept TileBlockGemmDescriptor = requires(T t) {
{ t.warp.m } -> std::convertible_to<int>;
{ t.warp.n } -> std::convertible_to<int>;
{ t.warp.k } -> std::convertible_to<int>;
{ t.warp_tile.m } -> std::convertible_to<int>;
{ t.warp_tile.n } -> std::convertible_to<int>;
{ t.warp_tile.k } -> std::convertible_to<int>;
{ t.double_smem_buffer } -> std::convertible_to<bool>;
{ t.num_wave_groups } -> std::convertible_to<int>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
// concept.
template <typename T>
@@ -110,6 +143,12 @@ concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if struct specifies thread block info (CK Tile).
// Satisfied when `T::thread_block` names a member that models TileThreadBlockDescriptor.
template <typename T>
concept SpecifiesTileThreadBlock = requires {
{ T::thread_block } -> TileThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
@@ -130,6 +169,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar-per-vector info for A, B and C
// (CK Tile). Each of `T::transfer.{a,b,c}_scalar_per_vector` must be convertible to size_t.
template <typename T>
concept SpecifiesTileTransfer = requires(T t) {
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
@@ -159,6 +206,21 @@ concept SpecifiesBlockGemm = requires {
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
// Requires a `block_gemm` member with `warps` and `warp_tile` (int-convertible m/n/k fields),
// `double_smem_buffer`, `num_wave_groups`, `pipeline_version` and `scheduler`.
// NOTE(review): this concept reads `block_gemm.warps.*`, while TileBlockGemmDescriptor reads
// `warp.*` -- the member spellings differ; confirm which one is intended.
template <typename T>
concept SpecifiesTileBlockGemm = requires {
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;

View File

@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for input and output vector transfer (CK Tile).
// Satisfied only when Value.a, Value.b and Value.c are all strictly greater than zero.
template <auto Value>
concept TileInputOutputVectorTransferLimits =
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {

View File

@@ -59,6 +59,7 @@
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_tile_factory.hpp"
namespace ck_tile::builder::factory {
@@ -81,6 +82,14 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// CK Tile kernel.
// True when T is a conv algorithm descriptor that also specifies the CK Tile thread block,
// the CK Tile transfer, a forward convolution specialization and the CK Tile block GEMM.
template <typename T>
consteval bool IsTileAlgorithm()
{
    constexpr bool is_algorithm       = ConvAlgorithmDescriptor<T>;
    constexpr bool has_thread_block   = SpecifiesTileThreadBlock<T>;
    constexpr bool has_transfer       = SpecifiesTileTransfer<T>;
    constexpr bool has_specialization = SpecifiesFwdConcSpecialization<T>;
    constexpr bool has_block_gemm     = SpecifiesTileBlockGemm<T>;

    return is_algorithm && has_thread_block && has_transfer && has_specialization &&
           has_block_gemm;
}
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
@@ -141,7 +150,11 @@ constexpr auto make_conv_instance()
{
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
if constexpr(IsXdlV3Algorithm<AlgoType>())
if constexpr(IsTileAlgorithm<AlgoType>())
{
return typename ConvFwdTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsXdlV3Algorithm<AlgoType>())
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}

View File

@@ -8,11 +8,11 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -9,12 +9,12 @@
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -0,0 +1,142 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for a ck_tile::GroupedConvolutionForwardKernel instance
// of a grouped forward convolution kernel (CK Tile).
//
// The builder SIGNATURE and ALGORITHM descriptors are translated into the CK Tile types needed
// by the kernel: convolution traits, GEMM shape, tile partitioner, GEMM pipeline and epilogue.
// `Instance` is an alias template; the caller supplies the hot-loop flag, the tail number and
// the output memory operation.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsForward<SIGNATURE>
struct ConvFwdTileFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts = decltype(internal::GetTileTensorLayout<SIGNATURE.layout,
                                                           SPATIAL_DIM,
                                                           ConvDirection::FORWARD>());
    // Use the CK Tile mappings provided by the ck_tile helper headers included by this file
    // (TileConvTensorTypes / TileElementwiseOps), not the old CK ones.
    using Types         = internal::TileConvTensorTypes<SIGNATURE.data_type>;
    using Ops           = internal::TileElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using AlgorithmType = decltype(ALGORITHM);

    static constexpr auto FWD_CONV_SPECIALIZATION =
        internal::SetTileFwdConvSpecialization<ALGORITHM>();
    static constexpr auto BLOCK      = internal::SetTileThreadBlockInfo<ALGORITHM>();
    static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
    // SetTileBlockTransfer is constrained on the algorithm descriptor (like the other Set*
    // helpers) and reads ALGORITHM.transfer itself, so the whole algorithm is passed.
    static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();

    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);

    using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
                                                             FWD_CONV_SPECIALIZATION,
                                                             typename Layouts::ALayout,
                                                             typename Layouts::BLayout,
                                                             typename Layouts::DsLayout,
                                                             typename Layouts::ELayout,
                                                             SCALAR_PER_VECTOR.a,
                                                             SCALAR_PER_VECTOR.b,
                                                             SCALAR_PER_VECTOR.c,
                                                             ALGORITHM.num_groups_to_merge,
                                                             ALGORITHM.split_image,
                                                             ALGORITHM.explicit_gemm>;

    // Block tile / warp count / warp tile sizes, each as an (m, n, k) sequence.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
        ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
        ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;

    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
        GemmShape,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;

    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
        GroupedConvTraitsType::FixedGemmParams::kPadM,
        GroupedConvTraitsType::FixedGemmParams::kPadN,
        GroupedConvTraitsType::FixedGemmParams::kPadK,
        BLOCK_GEMM.double_smem_buffer,
        typename GroupedConvTraitsType::AsLayoutFwd,
        typename GroupedConvTraitsType::BsLayoutFwd,
        typename GroupedConvTraitsType::CLayoutFwd,
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
        GroupedConvTraitsType::FixedGemmParams::Persistent,
        BLOCK_GEMM.num_wave_groups>;

    template <bool has_hot_loop, ck_tile::TailNumber tail_number>
    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
        typename Types::ADataType,
        typename Types::BDataType,
        typename Types::AccDataType,
        GemmShape,
        GemmUniversalTraits,
        BLOCK_GEMM.loop_scheduler,
        has_hot_loop,
        tail_number,
        typename Ops::AElementwiseOp,
        typename Ops::BElementwiseOp,
        typename Types::EDataType,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        GroupedConvTraitsType::VectorSizeA,
        GroupedConvTraitsType::VectorSizeB>;

    // The pipeline class template is selected from the pipeline version of the block GEMM spec.
    template <bool has_hot_loop, ck_tile::TailNumber tail_number>
    using GemmPipeline =
        typename internal::TilePipelineType<BLOCK_GEMM.pipeline_version>::template GemmPipeline<
            UniversalGemmProblem<has_hot_loop, tail_number>>;

    template <ck_tile::memory_operation_enum memory_operation>
    using ConvEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<typename Types::ADataType,
                                         typename Types::BDataType,
                                         typename Types::DsDataTypes,
                                         typename Types::AccDataType,
                                         typename Types::EDataType,
                                         typename GroupedConvTraitsType::ImplicitGemmDsLayout,
                                         typename GroupedConvTraitsType::FixedGemmParams::ELayout,
                                         typename Ops::CDEElementwiseOp,
                                         BLOCK.per_block.m,
                                         BLOCK.per_block.n,
                                         BLOCK_GEMM.warps.m,
                                         BLOCK_GEMM.warps.n,
                                         BLOCK_GEMM.warp_tile.m,
                                         BLOCK_GEMM.warp_tile.n,
                                         BLOCK_GEMM.warp_tile.k,
                                         GroupedConvTraitsType::FixedGemmParams::TransposeC,
                                         memory_operation,
                                         BLOCK_GEMM.num_wave_groups,
                                         GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
                                         SCALAR_PER_VECTOR.c>>;

    template <bool has_hot_loop,
              ck_tile::TailNumber tail_number,
              ck_tile::memory_operation_enum memory_operation>
    using Instance =
        ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
                                                 TilePartitioner,
                                                 GemmPipeline<has_hot_loop, tail_number>,
                                                 ConvEpilogue<memory_operation>>;
};
} // namespace ck_tile::builder::factory

View File

@@ -9,12 +9,12 @@
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -9,12 +9,12 @@
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -9,12 +9,12 @@
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Scalar-per-vector values for the A, B and C tensors; every field defaults to zero.
struct TileScalarPerVector
{
    size_t a{0};
    size_t b{0};
    size_t c{0};
};
// Collects the A/B/C scalar-per-vector values from ALGORITHM.transfer.
//
// The member names follow the SpecifiesTileTransfer concept
// (`transfer.{a,b,c}_scalar_per_vector`); the tile transfer descriptor is flat and has no
// nested `a` / `b` / `c` sub-objects.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileScalarPerVector SetTileBlockTransfer()
{
    return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector,
                               .b = ALGORITHM.transfer.b_scalar_per_vector,
                               .c = ALGORITHM.transfer.c_scalar_per_vector};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from the builder ElementwiseOperation enum to CK Tile elementwise operator types
// for the A, B and CDE operands.
template <ElementwiseOperation T>
struct TileElementwiseOps
{
// This will trigger if a specialization for the given ElementwiseOperation is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported elementwise operation for convolution factory.");
};
// PASS_THROUGH: A, B and CDE all use PassThrough.
template <>
struct TileElementwiseOps<ElementwiseOperation::PASS_THROUGH>
{
using AElementwiseOp = ck_tile::element_wise::PassThrough;
using BElementwiseOp = ck_tile::element_wise::PassThrough;
using CDEElementwiseOp = ck_tile::element_wise::PassThrough;
};
// SCALE: A and B use PassThrough, CDE uses Scale.
template <>
struct TileElementwiseOps<ElementwiseOperation::SCALE>
{
using AElementwiseOp = ck_tile::element_wise::PassThrough;
using BElementwiseOp = ck_tile::element_wise::PassThrough;
using CDEElementwiseOp = ck_tile::element_wise::Scale;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,101 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from the builder group convolution layout enum values to the CK Tile tensor
// layout types. The primary template is the fallback: instantiating it always fails the
// static_assert below; supported combinations are provided as explicit specializations.
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
struct TileConvTensorLayouts
{
// This will trigger if a specialization for the given layout is not found.
// We should always catch this in an earlier validation check.
// (sizeof of a complete type is never 0; the condition depends on LayoutValue's type so it is
// only evaluated on instantiation.)
using Layout = decltype(LayoutValue);
static_assert(sizeof(Layout) == 0,
"Internal error. Unsupported layout for convolution factory.");
};
// 1D Forward Convolution Layout Specializations
// Each specialization names the A (input), B (weight) and E (output) layout types; DsLayout is
// an empty tuple in all of them.
template <>
struct TileConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
using BLayout = ck_tile::tensor_layout::convolution::GKXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::NWGK;
};
template <>
struct TileConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::GNWC;
using BLayout = ck_tile::tensor_layout::convolution::GKXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::GNWK;
};
// 2D Forward Convolution Layout Specializations
template <>
struct TileConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::NHWGC;
using BLayout = ck_tile::tensor_layout::convolution::GKYXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::NHWGK;
};
template <>
struct TileConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::GNHWC;
using BLayout = ck_tile::tensor_layout::convolution::GKYXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::GNHWK;
};
// 3D Forward Convolution Layout Specializations
template <>
struct TileConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::NDHWGC;
using BLayout = ck_tile::tensor_layout::convolution::GKZYXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::NDHWGK;
};
template <>
struct TileConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
{
using ALayout = ck_tile::tensor_layout::convolution::GNDHWC;
using BLayout = ck_tile::tensor_layout::convolution::GKZYXC;
using DsLayout = ck_tile::tuple<>;
using ELayout = ck_tile::tensor_layout::convolution::GNDHWK;
};
// Selects the TileConvTensorLayouts mapping for the given spatial dimension by picking the
// matching per-dimension member of Layout (`_1d`, `_2d` or `_3d`) and returns a
// value-initialized instance of it (callers use decltype on the result).
// Any spatial dimension other than 1, 2 or 3 is rejected at compile time.
template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
consteval auto GetTileTensorLayout()
{
    if constexpr(SPATIAL_DIM == 1)
    {
        return internal::TileConvTensorLayouts<Layout._1d, 1, DIR>{};
    }
    else if constexpr(SPATIAL_DIM == 2)
    {
        return internal::TileConvTensorLayouts<Layout._2d, 2, DIR>{};
    }
    else if constexpr(SPATIAL_DIM == 3)
    {
        return internal::TileConvTensorLayouts<Layout._3d, 3, DIR>{};
    }
    else
    {
        // The condition depends on SPATIAL_DIM, so it is only evaluated when this branch is
        // instantiated (where it is always false). A plain static_assert(false) in a template
        // is ill-formed for compilers that do not implement P2593.
        static_assert(SPATIAL_DIM >= 1 && SPATIAL_DIM <= 3,
                      "Unsupported spatial dimension for convolution layout.");
    }
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from builder convolution data type to CK Tile tensor types.
// The primary template is the fallback for unsupported DataType values; supported values are
// provided as explicit specializations.
template <DataType T>
struct TileConvTensorTypes
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported data type for convolution factory.");
};
// FP16: half_t tensors and compute types, float accumulation, no D tensors.
template <>
struct TileConvTensorTypes<DataType::FP16>
{
using ADataType = ck_tile::half_t;
using AComputeType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using BComputeType = ck_tile::half_t;
using CShuffleDataType = ck_tile::half_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::half_t;
};
// BF16: bf16_t tensors and compute types, float accumulation, no D tensors.
template <>
struct TileConvTensorTypes<DataType::BF16>
{
using ADataType = ck_tile::bf16_t;
using AComputeType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using BComputeType = ck_tile::bf16_t;
using CShuffleDataType = ck_tile::bf16_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::bf16_t;
};
// FP32: float everywhere, no D tensors.
template <>
struct TileConvTensorTypes<DataType::FP32>
{
using ADataType = float;
using AComputeType = float;
using BDataType = float;
using BComputeType = float;
using CShuffleDataType = float;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = float;
};
// I8: int8_t tensors and compute types, int32_t accumulation, no D tensors.
template <>
struct TileConvTensorTypes<DataType::I8>
{
using ADataType = int8_t;
using AComputeType = int8_t;
using BDataType = int8_t;
using BComputeType = int8_t;
using CShuffleDataType = int8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = int32_t;
using EDataType = int8_t;
};
// FP8: fp8_t tensors and compute types, float accumulation, no D tensors.
template <>
struct TileConvTensorTypes<DataType::FP8>
{
using ADataType = ck_tile::fp8_t;
using AComputeType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using BComputeType = ck_tile::fp8_t;
using CShuffleDataType = ck_tile::fp8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::fp8_t;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Plain (m, n, k) triple; all three values default to zero.
struct TileBlockMNK
{
    int m = 0;
    int n = 0;
    int k = 0;
};

// Thread-block information for a CK Tile convolution: the per-block (m, n, k) tile sizes.
struct TileConvBlock
{
    TileBlockMNK per_block{};
};
// Builds the per-block tile sizes from ALGORITHM.thread_block.tile_size, copying m, n and k
// one by one.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileConvBlock SetTileThreadBlockInfo()
{
    constexpr auto& tile = ALGORITHM.thread_block.tile_size;
    return TileConvBlock{
        .per_block = {.m = tile.m, .n = tile.n, .k = tile.k},
    };
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,137 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockGemmMNK
{
int m{};
int n{};
int k{};
};
struct TileBlockGemmSpec
{
TileBlockGemmMNK warps = {};
TileBlockGemmMNK warp_tile = {};
bool double_smem_buffer = false;
int num_wave_groups = 1;
ck_tile::GemmPipelineScheduler pipeline_version;
ck_tile::GemmPipeline loop_scheduler;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipelineScheduler SetTileLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
using ck_tile_loop_sched = ck_tile::GemmPipelineScheduler;
switch(loop_scheduler)
{
case PipelineScheduler::DEFAULT: return ck_tile_loop_sched::Default;
case PipelineScheduler::INTERWAVE: return ck_tile_loop_sched::Interwave;
case PipelineScheduler::INTRAWAVE: return ck_tile_loop_sched::Intrawave;
default: throw "Unknown PipelineScheduler";
}
}
template <ck_tile::GemmPipeline PipelineId>
struct TilePipelineType
{
static_assert(false, "Unknown PipelineScheduler");
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
using ck_tile_pipeline = ck_tile::GemmPipeline;
switch(version)
{
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::ConvolutionSpecialization SetTileFwdConvSpecialization()
{
constexpr auto specialization = ALGORITHM.fwd_specialization;
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
switch(specialization)
{
case ConvFwdSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
default: throw "Unknown ConvFwdSpecialization";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileBlockGemmSpec SetTileBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm;
constexpr TileBlockGemmMNK warps = BG.warps;
constexpr TileBlockGemmMNK warp_tile = BG.warp_tile;
constexpr bool double_smem_buffer = BG.double_smem_buffer;
constexpr int num_wave_groups = BG.num_wave_groups;
constexpr ck_tile::GemmPipelineScheduler pipeline_version =
SetTileBlockGemmPipelineVersion<ALGORITHM>();
constexpr ck_tile::GemmPipeline loop_scheduler = SetTileLoopScheduler<ALGORITHM>();
return TileBlockGemmSpec{.warps = warps,
.warp_tile = warp_tile,
.double_smem_buffer = double_smem_buffer,
.num_wave_groups = num_wave_groups,
.pipeline_version = pipeline_version,
.loop_scheduler = loop_scheduler};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/test_conv_traits.cpp)
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
@@ -119,18 +119,19 @@ add_ck_builder_test(test_ckb_get_instance_string
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp
conv/test_ckb_conv_fwd_2d_fp8.cpp
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
conv/ck/test_ckb_conv_fwd_1d_bf16.cpp
conv/ck/test_ckb_conv_fwd_1d_i8.cpp
conv/ck/test_ckb_conv_fwd_2d_fp8.cpp
conv/ck/test_ckb_conv_fwd_2d_bf16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp32.cpp
conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_bf16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16.cpp
)

View File

@@ -0,0 +1,34 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
// Builds a 2D FP16 NHWGC forward convolution through the CK Tile algorithm descriptor and checks
// the resulting instance description.
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionForwardKernel_2D_FP16_NHWGC)
{
    constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
                                             .direction   = ConvDirection::FORWARD,
                                             .layout      = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
                                             .data_type   = DataType::FP16,
                                             .elementwise_operation =
                                                 ElementwiseOperation::PASS_THROUGH};

    // The with_tile_* setters static_assert that the algorithm derives from the tile component
    // structs (TileThreadBlock_, TileBlockGemm_, TileTransfer_), so the CK Tile algorithm alias
    // has to be used here rather than the XDL CShuffle V3 one.
    constexpr auto FwdConvAlgorithm =
        ConvAlgorithm_Tile_GroupedConvolutionForwardKernel{}
            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v1_intrawave)
            .with_tile_transfer(FwdTileTransfer_4x4x4);

    using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

    // NOTE(review): the expected substrings below name the XDL CShuffle V3 device op, a
    // 256x256x256 block, Filter1x1Pad0 and pipeline v3, which do not obviously match the tile
    // configuration above (64x64x64, v1) -- TODO confirm against the instance string produced
    // for the CK Tile kernel.
    run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
                       "256, 256, 256, 32",
                       "Filter1x1Pad0",
                       "BlkGemmPipelineScheduler: Intrawave",
                       "BlkGemmPipelineVersion: v3"});
}
} // namespace

View File

@@ -243,6 +243,52 @@ struct LargeTensorWrapper
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Specify thread block dimensions for a GEMM (CK Tile).
// Checked below against the TileThreadBlockDescriptor concept.
struct TileThreadBlock
{
// Size of the submatrix problem in a thread block.
MNK<size_t> tile_size;
};
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
// Specify the scalar-per-vector values for the A, B and C tensors (CK Tile).
// Checked below against the TileTransferDescriptor concept.
struct TileTransfer
{
size_t a_scalar_per_vector;
size_t b_scalar_per_vector;
size_t c_scalar_per_vector;
};
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
// Specify the block GEMM configuration (CK Tile).
// Checked below against the TileBlockGemmDescriptor concept.
// NOTE(review): SpecifiesTileBlockGemm and SetTileBlockGemm read `block_gemm.warps`, while the
// member here (and in TileBlockGemmDescriptor) is spelled `warp` -- confirm which is intended.
struct TileBlockGemm
{
// Number of warps per each dimension.
MNK<int> warp;
// Number of data processed per each dimension for each XDL/WMMA instruction.
MNK<int> warp_tile;
// Double LDS buffer.
bool double_smem_buffer;
// Waves grouping (Ping-Pong scheduler).
int num_wave_groups;
// Block GEMM pipeline version and scheduler (builder enums).
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
// Algorithm component: when used as a base of ConvAlgorithmTemplate it contributes the
// `thread_block` member that with_tile_thread_block() assigns.
struct TileThreadBlock_
{
TileThreadBlock thread_block;
};
// Algorithm component: when used as a base of ConvAlgorithmTemplate it contributes the
// `transfer` member that with_tile_transfer() assigns.
struct TileTransfer_
{
TileTransfer transfer;
};
// Algorithm component: when used as a base of ConvAlgorithmTemplate it contributes the
// `block_gemm` member that with_tile_block_gemm() assigns.
struct TileBlockGemm_
{
TileBlockGemm block_gemm;
};
// Factory
template <typename... Components>
@@ -339,6 +385,33 @@ struct ConvAlgorithmTemplate : Components...
result.transfer = t;
return result;
}
template <typename TB>
constexpr auto with_tile_thread_block(const TB& tb) const
{
static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
auto result = *this;
result.thread_block = tb;
return result;
}
template <typename BG>
constexpr auto with_tile_block_gemm(const BG& bg) const
{
static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_gemm = bg;
return result;
}
template <typename T>
constexpr auto with_tile_transfer(const T& t) const
{
static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
}
};
// Algorithm types
@@ -361,4 +434,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
// Large-tensor variant: the XDL CShuffle forward algorithm wrapped in LargeTensorWrapper.
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
// CK Tile grouped forward convolution algorithm, composed of the Tile thread-block, block-GEMM
// and transfer components plus ConvSpecialization_. Its with_tile_*() setters are the ones whose
// static_asserts are satisfied by this composition.
using ConvAlgorithm_Tile_GroupedConvolutionForwardKernel =
ConvAlgorithmTemplate<TileThreadBlock_, TileBlockGemm_, TileTransfer_, ConvSpecialization_>;
} // namespace ck_tile::builder::test

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
namespace {

View File

@@ -5,7 +5,7 @@
#include <type_traits>
// Include the helper file we're testing
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
namespace {

View File

@@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace {

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
namespace {

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Transfer presets. The AxBxC suffix lists the a/b/c scalars-per-vector values in that order.
constexpr TileTransfer FwdTileTransfer_1x1x1{
.a_scalar_per_vector = 1,
.b_scalar_per_vector = 1,
.c_scalar_per_vector = 1,
};
constexpr TileTransfer FwdTileTransfer_4x4x4{
.a_scalar_per_vector = 4,
.b_scalar_per_vector = 4,
.c_scalar_per_vector = 4,
};
constexpr TileTransfer FwdTileTransfer_8x8x8{
.a_scalar_per_vector = 8,
.b_scalar_per_vector = 8,
.c_scalar_per_vector = 8,
};
// Thread-block tile presets. The MxNxK suffix matches the tile_size stored in each constant.
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
// Builds the block-GEMM descriptor shared by the presets below: 2x2x1 warps, a 16x16x16 warp
// tile, double_smem_buffer disabled, one wave group and the INTRAWAVE scheduler. Only the
// pipeline version differs between presets, so it is the single parameter.
constexpr TileBlockGemm make_tile_block_gemm_16x16_intrawave(PipelineVersion version)
{
    return TileBlockGemm{.warp               = {.m = 2, .n = 2, .k = 1},
                         .warp_tile          = {.m = 16, .n = 16, .k = 16},
                         .double_smem_buffer = false,
                         .num_wave_groups    = 1,
                         .pipeline_version   = version,
                         .scheduler          = PipelineScheduler::INTRAWAVE};
}

// One preset per pipeline version (V1..V5); the vN part of the name is the pipeline version.
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave =
    make_tile_block_gemm_16x16_intrawave(PipelineVersion::V1);
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave =
    make_tile_block_gemm_16x16_intrawave(PipelineVersion::V2);
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave =
    make_tile_block_gemm_16x16_intrawave(PipelineVersion::V3);
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave =
    make_tile_block_gemm_16x16_intrawave(PipelineVersion::V4);
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave =
    make_tile_block_gemm_16x16_intrawave(PipelineVersion::V5);
} // namespace ck_tile::builder::test_utils