mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
[BUILDER] Ck Tile Grouped convolution factory
This commit is contained in:
@@ -95,6 +95,39 @@ concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileThreadBlockDescriptor = requires(T t) {
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileTransferDescriptor = requires(T t) {
|
||||
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept TileBlockGemmDescriptor = requires(T t) {
|
||||
{ t.warp.m } -> std::convertible_to<int>;
|
||||
{ t.warp.n } -> std::convertible_to<int>;
|
||||
{ t.warp.k } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ t.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ t.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
// concept.
|
||||
template <typename T>
|
||||
@@ -110,6 +143,12 @@ concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies thread block info (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileThreadBlock = requires {
|
||||
{ T::thread_block } -> TileThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
@@ -130,6 +169,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransfer = requires(T t) {
|
||||
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
@@ -159,6 +206,21 @@ concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileBlockGemm = requires {
|
||||
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
|
||||
@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for input and output vector transfer (CK Tile).
|
||||
template <auto Value>
|
||||
concept TileInputOutputVectorTransferLimits =
|
||||
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
|
||||
@@ -59,6 +59,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_tile_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -81,6 +82,14 @@ namespace ck_tile::builder::factory {
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
consteval bool IsTileAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesTileBlockGemm<T>;
|
||||
}
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
@@ -141,7 +150,11 @@ constexpr auto make_conv_instance()
|
||||
{
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
if constexpr(IsTileAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
|
||||
@@ -8,11 +8,11 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdTileFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(internal::GetTileTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
internal::SetTileFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
|
||||
static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM.transfer>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
|
||||
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
SCALAR_PER_VECTOR.a,
|
||||
SCALAR_PER_VECTOR.b,
|
||||
SCALAR_PER_VECTOR.c,
|
||||
ALGORITHM.num_groups_to_merge,
|
||||
ALGORITHM.split_image,
|
||||
ALGORITHM.explicit_gemm>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
BLOCK_GEMM.double_smem_buffer,
|
||||
typename GroupedConvTraitsType::AsLayoutFwd,
|
||||
typename GroupedConvTraitsType::BsLayoutFwd,
|
||||
typename GroupedConvTraitsType::CLayoutFwd,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
BLOCK_GEMM.num_wave_groups>;
|
||||
|
||||
template <bool has_hot_loop, ck_tile::TailNumber tail_number>
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
BLOCK_GEMM.loop_scheduler,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Types::EDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
template <bool has_hot_loop, ck_tile::TailNumber tail_number>
|
||||
using GemmPipeline =
|
||||
typename internal::TilePipelineType<BLOCK_GEMM.pipeline_version>::template GemmPipeline<
|
||||
UniversalGemmProblem<has_hot_loop, tail_number>>;
|
||||
|
||||
template <ck_tile::memory_operation_enum memory_operation>
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK_GEMM.warps.m,
|
||||
BLOCK_GEMM.warps.n,
|
||||
BLOCK_GEMM.warp_tile.m,
|
||||
BLOCK_GEMM.warp_tile.n,
|
||||
BLOCK_GEMM.warp_tile.k,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
BLOCK_GEMM.num_wave_groups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
SCALAR_PER_VECTOR.c>>;
|
||||
|
||||
template <bool has_hot_loop,
|
||||
ck_tile::TailNumber tail_number,
|
||||
ck_tile::memory_operation_enum memory_operation>
|
||||
using Instance =
|
||||
ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline<has_hot_loop, tail_number>,
|
||||
ConvEpilogue<memory_operation>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -9,12 +9,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Scalar-per-vector values for the A, B and C tensors; each defaults to zero.
struct TileScalarPerVector
{
    size_t a{};
    size_t b{};
    size_t c{};
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileScalarPerVector SetTileBlockTransfer()
|
||||
{
|
||||
return TileScalarPerVector{.a = ALGORITHM.transfer.a.scalar_per_vector,
|
||||
.b = ALGORITHM.transfer.b.scalar_per_vector,
|
||||
.c = ALGORITHM.transfer.c.scalar_per_vector};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct TileElementwiseOps
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported elementwise operation for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using AElementwiseOp = ck_tile::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck_tile::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck_tile::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileElementwiseOps<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using AElementwiseOp = ck_tile::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck_tile::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck_tile::element_wise::Scale;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK Tile tensor data types.
|
||||
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
|
||||
struct TileConvTensorLayouts
|
||||
{
|
||||
// This will trigger if a specialization for the given layout is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
using Layout = decltype(LayoutValue);
|
||||
static_assert(sizeof(Layout) == 0,
|
||||
"Internal error. Unsupported layout for convolution factory.");
|
||||
};
|
||||
|
||||
// 1D Forward Convolution Layout Specializations
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::GNWC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::GNHWC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck_tile::tensor_layout::convolution::GNDHWC;
|
||||
using BLayout = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using ELayout = ck_tile::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
consteval auto GetTileTensorLayout()
|
||||
{
|
||||
|
||||
if constexpr(SPATIAL_DIM == 1)
|
||||
{
|
||||
return internal::TileConvTensorLayouts<Layout._1d, 1, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 2)
|
||||
{
|
||||
return internal::TileConvTensorLayouts<Layout._2d, 2, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 3)
|
||||
{
|
||||
return internal::TileConvTensorLayouts<Layout._3d, 3, DIR>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported spatial dimension for convolution layout.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from builder convolution data type to CK Tile tensor types.
|
||||
template <DataType T>
|
||||
struct TileConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using AComputeType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using BComputeType = ck_tile::half_t;
|
||||
using CShuffleDataType = ck_tile::half_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using AComputeType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using BComputeType = ck_tile::bf16_t;
|
||||
using CShuffleDataType = ck_tile::bf16_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using AComputeType = float;
|
||||
using BDataType = float;
|
||||
using BComputeType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::I8>
|
||||
{
|
||||
using ADataType = int8_t;
|
||||
using AComputeType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using BComputeType = int8_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = int32_t;
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using AComputeType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using BComputeType = ck_tile::fp8_t;
|
||||
using CShuffleDataType = ck_tile::fp8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Plain (m, n, k) triple; every component defaults to zero.
struct TileBlockMNK
{
    int m = 0;
    int n = 0;
    int k = 0;
};
|
||||
|
||||
struct TileConvBlock
|
||||
{
|
||||
TileBlockMNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileConvBlock SetTileThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return TileConvBlock{
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Plain (m, n, k) triple used by the block GEMM spec; every component defaults to zero.
struct TileBlockGemmMNK
{
    int m = 0;
    int n = 0;
    int k = 0;
};
|
||||
|
||||
struct TileBlockGemmSpec
|
||||
{
|
||||
TileBlockGemmMNK warps = {};
|
||||
TileBlockGemmMNK warp_tile = {};
|
||||
|
||||
bool double_smem_buffer = false;
|
||||
int num_wave_groups = 1;
|
||||
|
||||
ck_tile::GemmPipelineScheduler pipeline_version;
|
||||
ck_tile::GemmPipeline loop_scheduler;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipelineScheduler SetTileLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
using ck_tile_loop_sched = ck_tile::GemmPipelineScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
case PipelineScheduler::DEFAULT: return ck_tile_loop_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_tile_loop_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: return ck_tile_loop_sched::Intrawave;
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct TilePipelineType
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
using ck_tile_pipeline = ck_tile::GemmPipeline;
|
||||
switch(version)
|
||||
{
|
||||
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
|
||||
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
|
||||
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::ConvolutionSpecialization SetTileFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
|
||||
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval TileBlockGemmSpec SetTileBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
constexpr TileBlockGemmMNK warps = BG.warps;
|
||||
constexpr TileBlockGemmMNK warp_tile = BG.warp_tile;
|
||||
|
||||
constexpr bool double_smem_buffer = BG.double_smem_buffer;
|
||||
constexpr int num_wave_groups = BG.num_wave_groups;
|
||||
|
||||
constexpr ck_tile::GemmPipelineScheduler pipeline_version =
|
||||
SetTileBlockGemmPipelineVersion<ALGORITHM>();
|
||||
constexpr ck_tile::GemmPipeline loop_scheduler = SetTileLoopScheduler<ALGORITHM>();
|
||||
|
||||
return TileBlockGemmSpec{.warps = warps,
|
||||
.warp_tile = warp_tile,
|
||||
.double_smem_buffer = double_smem_buffer,
|
||||
.num_wave_groups = num_wave_groups,
|
||||
.pipeline_version = pipeline_version,
|
||||
.loop_scheduler = loop_scheduler};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder
|
||||
|
||||
# Tests convolution trait selection and configuration
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
conv/ck/test_conv_traits.cpp)
|
||||
|
||||
# Tests convolution problem description and parameter handling
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
@@ -119,18 +119,19 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
# Tests the forward convolution builder across multiple data types and dimensions.
|
||||
# Individual tests are split into separate files to enable parallel compilation.
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
// Builds a 2D FP16 forward grouped convolution (NHWGC / GKYXC / NHWGK, pass-through
// elementwise op) from the tile thread-block, block-GEMM and transfer configs and
// hands the expected instance-string fragments to run_test.
//
// NOTE(review): the test name refers to the Tile grouped-convolution kernel, but
// the algorithm starts from ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// and the expected fragments name that CK instance ("256, 256, 256, 32", "v3")
// although a 64x64x64 tile and the v1 block-GEMM descriptor are configured.
// Confirm the intended algorithm type and expectations.
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionForwardKernel_2D_FP16_NHWGC)
{
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH};

    constexpr auto kBaseAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{};
    constexpr auto kAlgorithm     = kBaseAlgorithm.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
                                    .with_tile_block_gemm(TileBlockGemmDesc_16x16_v1_intrawave)
                                    .with_tile_transfer(FwdTileTransfer_4x4x4);

    using TileFwdBuilder = ConvBuilder<kSignature, kAlgorithm>;

    run_test<TileFwdBuilder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
                              "256, 256, 256, 32",
                              "Filter1x1Pad0",
                              "BlkGemmPipelineScheduler: Intrawave",
                              "BlkGemmPipelineVersion: v3"});
}
|
||||
|
||||
} // namespace
|
||||
@@ -243,6 +243,52 @@ struct LargeTensorWrapper
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM (CK Tile).
|
||||
struct TileThreadBlock
|
||||
{
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
|
||||
|
||||
struct TileTransfer
|
||||
{
|
||||
size_t a_scalar_per_vector;
|
||||
size_t b_scalar_per_vector;
|
||||
size_t c_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
|
||||
|
||||
struct TileBlockGemm
|
||||
{
|
||||
// Number of warps per each dimension.
|
||||
MNK<int> warp;
|
||||
// Number of data processed per each dimension for each XDL/WMMA instruction.
|
||||
MNK<int> warp_tile;
|
||||
// Double LDS buffer.
|
||||
bool double_smem_buffer;
|
||||
// Waves grouping (Ping-Pong scheduler).
|
||||
int num_wave_groups;
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
|
||||
|
||||
struct TileThreadBlock_
|
||||
{
|
||||
TileThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct TileTransfer_
|
||||
{
|
||||
TileTransfer transfer;
|
||||
};
|
||||
|
||||
struct TileBlockGemm_
|
||||
{
|
||||
TileBlockGemm block_gemm;
|
||||
};
|
||||
|
||||
// Factory
|
||||
|
||||
template <typename... Components>
|
||||
@@ -339,6 +385,33 @@ struct ConvAlgorithmTemplate : Components...
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TB>
|
||||
constexpr auto with_tile_thread_block(const TB& tb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_block = tb;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_tile_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr auto with_tile_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
@@ -361,4 +434,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
using ConvAlgorithm_Tile_GroupedConvolutionForwardKernel =
|
||||
ConvAlgorithmTemplate<TileThreadBlock_, TileBlockGemm_, TileTransfer_, ConvSpecialization_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
// Include the helper file we're testing
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Builds a TileTransfer that uses the same scalar-per-vector count for A, B and C.
constexpr TileTransfer MakeUniformTileTransfer(size_t scalar_per_vector)
{
    return TileTransfer{.a_scalar_per_vector = scalar_per_vector,
                        .b_scalar_per_vector = scalar_per_vector,
                        .c_scalar_per_vector = scalar_per_vector};
}

// Forward transfer presets; the suffix lists the A x B x C scalar-per-vector counts.
constexpr TileTransfer FwdTileTransfer_1x1x1 = MakeUniformTileTransfer(1);

constexpr TileTransfer FwdTileTransfer_4x4x4 = MakeUniformTileTransfer(4);

constexpr TileTransfer FwdTileTransfer_8x8x8 = MakeUniformTileTransfer(8);
|
||||
|
||||
// Forward thread block presets; the suffix lists the M x N x K tile size.
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32 = {
    .tile_size = {.m = 256, .n = 256, .k = 32},
};

constexpr TileThreadBlock FwdTileThreadBlock_256x128x32 = {
    .tile_size = {.m = 256, .n = 128, .k = 32},
};

constexpr TileThreadBlock FwdTileThreadBlock_128x128x32 = {
    .tile_size = {.m = 128, .n = 128, .k = 32},
};

constexpr TileThreadBlock FwdTileThreadBlock_128x128x16 = {
    .tile_size = {.m = 128, .n = 128, .k = 16},
};

constexpr TileThreadBlock FwdTileThreadBlock_64x32x32 = {
    .tile_size = {.m = 64, .n = 32, .k = 32},
};

constexpr TileThreadBlock FwdTileThreadBlock_64x64x64 = {
    .tile_size = {.m = 64, .n = 64, .k = 64},
};
|
||||
|
||||
// Common shape of the block GEMM presets below: 2x2x1 warps, 16x16x16 warp tile,
// single LDS buffer, one wave group and the intrawave scheduler. Only the pipeline
// version differs between the presets.
constexpr TileBlockGemm MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion version)
{
    return TileBlockGemm{.warp               = {.m = 2, .n = 2, .k = 1},
                         .warp_tile          = {.m = 16, .n = 16, .k = 16},
                         .double_smem_buffer = false,
                         .num_wave_groups    = 1,
                         .pipeline_version   = version,
                         .scheduler          = PipelineScheduler::INTRAWAVE};
}

constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave =
    MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion::V1);

constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave =
    MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion::V2);

constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave =
    MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion::V3);

constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave =
    MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion::V4);

constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave =
    MakeTileBlockGemmDesc_16x16_intrawave(PipelineVersion::V5);
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
Reference in New Issue
Block a user