mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)
* [BUILDER] Ck Tile Grouped convolution factory * Part 2 * Fixes after rebase * Remove leftovers
This commit is contained in:
@@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileThreadBlockDescriptor = requires(T t) {
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileTransferDescriptor = requires(T t) {
|
||||
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept TileBlockGemmDescriptor = requires(T t) {
|
||||
{ t.warps.m } -> std::convertible_to<int>;
|
||||
{ t.warps.n } -> std::convertible_to<int>;
|
||||
{ t.warps.k } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ t.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ t.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies optimizations (CK Tile).
|
||||
template <typename T>
|
||||
concept TileOptimizationsDescriptor = requires(T t) {
|
||||
{ t.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ t.split_image } -> std::convertible_to<bool>;
|
||||
{ t.explicit_gemm } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
// concept.
|
||||
template <typename T>
|
||||
@@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies thread block info (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileThreadBlock = requires {
|
||||
{ T::thread_block } -> TileThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
@@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransfer = requires(T t) {
|
||||
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
@@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
concept SpecifiesTileBlockGemm = requires {
|
||||
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileOptimizations = requires {
|
||||
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ T::optimizations.split_image } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileConvSpecialization = requires {
|
||||
{ T::specialization } -> std::convertible_to<TileConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConvSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for input and output vector transfer (CK Tile).
|
||||
template <auto Value>
|
||||
concept TileInputOutputVectorTransferLimits =
|
||||
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
|
||||
@@ -59,6 +59,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -81,6 +82,15 @@ namespace ck_tile::builder::factory {
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
consteval bool IsTileAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
|
||||
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
|
||||
SpecifiesTileOptimizations<T>;
|
||||
}
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
@@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
}
|
||||
|
||||
@@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
@@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
|
||||
@@ -120,7 +130,7 @@ template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
}
|
||||
@@ -137,10 +147,15 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance()
|
||||
{
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(IsTileAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for CK Tile Grouped Convolution kernels.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
struct ConvTileFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::TileElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
|
||||
static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations<ALGORITHM>();
|
||||
static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection<SIGNATURE>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
|
||||
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
|
||||
CONV_SPECIALIZATION,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
SCALAR_PER_VECTOR.a,
|
||||
SCALAR_PER_VECTOR.b,
|
||||
SCALAR_PER_VECTOR.c,
|
||||
OPTIMIZATIONS.num_groups_to_merge,
|
||||
OPTIMIZATIONS.split_image,
|
||||
OPTIMIZATIONS.explicit_gemm>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
BLOCK_GEMM.double_smem_buffer,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::AsLayout,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::BsLayout,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::CLayout,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
BLOCK_GEMM.num_wave_groups>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
BLOCK_GEMM.scheduler,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Types::EDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename internal::TilePipelineType<
|
||||
BLOCK_GEMM.pipeline_version>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK_GEMM.warps.m,
|
||||
BLOCK_GEMM.warps.n,
|
||||
BLOCK_GEMM.warp_tile.m,
|
||||
BLOCK_GEMM.warp_tile.n,
|
||||
BLOCK_GEMM.warp_tile.k,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
// TODO:: This template parameter will be moved inside the kernel
|
||||
ck_tile::memory_operation_enum::set,
|
||||
BLOCK_GEMM.num_wave_groups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
SCALAR_PER_VECTOR.c>>;
|
||||
|
||||
using Instance = typename internal::GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>::Instance;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
struct TileScalarPerVector
|
||||
{
|
||||
size_t a = 0;
|
||||
size_t b = 0;
|
||||
size_t c = 0;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileScalarPerVector SetTileBlockTransfer()
|
||||
{
|
||||
return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector,
|
||||
.b = ALGORITHM.transfer.b_scalar_per_vector,
|
||||
.c = ALGORITHM.transfer.c_scalar_per_vector};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ElementwiseOperation Op>
|
||||
struct ElementwiseOpToCKTile
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
|
||||
"Unsupported elementwise operation conversion to CK.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using Op = ck_tile::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using Op = ck_tile::element_wise::Scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::CLAMP>
|
||||
{
|
||||
using Op = ck_tile::element_wise::Clamp;
|
||||
};
|
||||
|
||||
template <auto TensorDesc>
|
||||
consteval auto GetTileElementwiseOp()
|
||||
{
|
||||
if constexpr(HasTensorOp<decltype(TensorDesc)>)
|
||||
{
|
||||
constexpr auto op = TensorDesc.operation.elementwise_operation;
|
||||
return ElementwiseOpToCKTile<op>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
struct TileElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetTileElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetTileElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetTileElementwiseOp<Sig.output>();
|
||||
using AElementwiseOp = typename decltype(input_op)::Op;
|
||||
using BElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using CDEElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
struct GroupedConvolutionTileKernel
|
||||
{
|
||||
static_assert(false, "Unknown Direction");
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE>
|
||||
consteval ck_tile::GroupedConvDirection SetTileConvDirection()
|
||||
{
|
||||
constexpr auto direction = SIGNATURE.direction;
|
||||
using ck_tile_direction = ck_tile::GroupedConvDirection;
|
||||
switch(direction)
|
||||
{
|
||||
case ConvDirection::FORWARD: return ck_tile_direction::FORWARD;
|
||||
case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA;
|
||||
case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT;
|
||||
default: throw "Unknown Direction";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,200 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
|
||||
template <TensorLayout Layout>
|
||||
struct LayoutToCKTile
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
|
||||
"Unsupported layout conversion to CK.");
|
||||
};
|
||||
|
||||
// Bias layouts
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::G_K_strided>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::G_K;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::G_C_strided>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::G_C;
|
||||
};
|
||||
|
||||
// Input 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNWC;
|
||||
};
|
||||
|
||||
// Input 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NHWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNHWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNHWC;
|
||||
};
|
||||
|
||||
// Input 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NDHWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNDHWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNDHWC;
|
||||
};
|
||||
|
||||
// Weight 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKXC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCX;
|
||||
};
|
||||
|
||||
// Weight 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKYXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCYX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCYX;
|
||||
};
|
||||
|
||||
// Weight 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCZYX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCZYX;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKZYXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
};
|
||||
|
||||
// Output 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
// Output 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NHWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNHWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
// Output 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NDHWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNDHWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <TensorLayout Layout>
|
||||
consteval auto TensorLayoutToCKTile()
|
||||
{
|
||||
return typename LayoutToCKTile<Layout>::type{};
|
||||
}
|
||||
|
||||
struct EmptyAuxiliaryTileTensorLayout
|
||||
{
|
||||
using type = ck_tile::tuple<>;
|
||||
};
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
|
||||
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
{
|
||||
return ck_tile::tuple<
|
||||
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct AuxiliaryTileTensorLayouts
|
||||
{
|
||||
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
|
||||
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
|
||||
std::make_index_sequence<Size>{}));
|
||||
};
|
||||
|
||||
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTileTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
struct TileConvTensorLayouts
|
||||
{
|
||||
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from builder convolution data type to CK Tile tensor types.
|
||||
template <DataType T>
|
||||
struct TileConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using AComputeType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using BComputeType = ck_tile::half_t;
|
||||
using CShuffleDataType = ck_tile::half_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using AComputeType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using BComputeType = ck_tile::bf16_t;
|
||||
using CShuffleDataType = ck_tile::bf16_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using AComputeType = float;
|
||||
using BDataType = float;
|
||||
using BComputeType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::I8>
|
||||
{
|
||||
using ADataType = int8_t;
|
||||
using AComputeType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using BComputeType = int8_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = int32_t;
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using AComputeType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using BComputeType = ck_tile::fp8_t;
|
||||
using CShuffleDataType = ck_tile::fp8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
struct TileBlockMNK
|
||||
{
|
||||
int m{};
|
||||
int n{};
|
||||
int k{};
|
||||
};
|
||||
|
||||
struct TileConvBlock
|
||||
{
|
||||
TileBlockMNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileConvBlock SetTileThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return TileConvBlock{
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
struct TileBlockGemmMNK
|
||||
{
|
||||
int m{};
|
||||
int n{};
|
||||
int k{};
|
||||
};
|
||||
|
||||
struct TileBlockGemmSpec
|
||||
{
|
||||
TileBlockGemmMNK warps = {};
|
||||
TileBlockGemmMNK warp_tile = {};
|
||||
|
||||
bool double_smem_buffer = false;
|
||||
int num_wave_groups = 1;
|
||||
|
||||
ck_tile::GemmPipeline pipeline_version;
|
||||
ck_tile::GemmPipelineScheduler scheduler;
|
||||
};
|
||||
|
||||
struct TileOptimizations
|
||||
{
|
||||
int num_groups_to_merge = 1;
|
||||
bool split_image = false;
|
||||
bool explicit_gemm = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipelineScheduler SetTileScheduler()
|
||||
{
|
||||
constexpr auto scheduler = ALGORITHM.block_gemm.scheduler;
|
||||
using ck_tile_sched = ck_tile::GemmPipelineScheduler;
|
||||
switch(scheduler)
|
||||
{
|
||||
case PipelineScheduler::DEFAULT: return ck_tile_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave;
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct TilePipelineType
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.block_gemm.pipeline_version;
|
||||
using ck_tile_pipeline = ck_tile::GemmPipeline;
|
||||
switch(version)
|
||||
{
|
||||
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
|
||||
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
|
||||
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.specialization;
|
||||
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
|
||||
case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
|
||||
case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0:
|
||||
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
|
||||
case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval TileBlockGemmSpec SetTileBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
constexpr bool double_smem_buffer = BG.double_smem_buffer;
|
||||
constexpr int num_wave_groups = BG.num_wave_groups;
|
||||
|
||||
constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion<ALGORITHM>();
|
||||
constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler<ALGORITHM>();
|
||||
|
||||
return TileBlockGemmSpec{
|
||||
.warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k},
|
||||
.warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k},
|
||||
.double_smem_buffer = double_smem_buffer,
|
||||
.num_wave_groups = num_wave_groups,
|
||||
.pipeline_version = pipeline_version,
|
||||
.scheduler = scheduler};
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval TileOptimizations SetTileOptimizations()
|
||||
{
|
||||
constexpr auto& OPT = ALGORITHM.optimizations;
|
||||
|
||||
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
|
||||
.split_image = OPT.split_image,
|
||||
.explicit_gemm = OPT.explicit_gemm};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -145,6 +145,15 @@ enum struct GemmSpecialization
|
||||
MNKOPadding
|
||||
};
|
||||
|
||||
// Enums for the CK Tile convolution specialization.
|
||||
enum class TileConvSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user