mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Initial conv bwd weight factory.
This commit is contained in:
@@ -27,8 +27,6 @@ concept ThreadBlockDescriptor = requires(T t) {
|
||||
// Concept for parameters that describe a gridwise XDL GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
@@ -159,7 +157,17 @@ concept SpecifiesTileThreadBlock = requires {
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires {
|
||||
{ T::gridwise_gemm.ak1 } -> std::convertible_to<size_t>;
|
||||
{ T::gridwise_gemm.bk1 } -> std::convertible_to<size_t>;
|
||||
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdXdlGemm = requires {
|
||||
{ T::gridwise_gemm.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ T::gridwise_gemm.k1 } -> std::convertible_to<size_t>;
|
||||
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
@@ -247,6 +255,11 @@ concept SpecifiesFwdConvSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesBwdWeightConvSpecialization = requires {
|
||||
{ T::bwd_weight_specialization } -> std::convertible_to<ConvolutionBackwardWeightSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGemmSpecialization = requires {
|
||||
{ T::gemm_specialization } -> std::convertible_to<GemmSpecialization>;
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k0_per_block,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -60,6 +60,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weigth_xdl_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -88,34 +89,43 @@ concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesDataTransfer =
|
||||
SpecifiesThreadBlock<T> && SpecifiesBlockTransfer<T> &&
|
||||
SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T>;
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
concept IsXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
concept IsFwdXdlV3Algorithm = ConvAlgorithmDescriptor<T> &&
|
||||
SpecifiesDataTransfer<T> && SpecifiesGridwiseFwdXdlGemm<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
// Standard XDL-based fwd kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
concept IsXdlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
concept IsFwdXdlAlgorithm = ConvAlgorithmDescriptor<T> &&
|
||||
SpecifiesDataTransfer<T> && SpecifiesGridwiseFwdXdlGemm<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Standard XDL-based bwd weight kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
concept IsBwdXdlAlgorithm = ConvAlgorithmDescriptor<T> &&
|
||||
SpecifiesDataTransfer<T> && SpecifiesGridwiseBwdXdlGemm<T> &&
|
||||
SpecifiesBwdWeightConvSpecialization<T> && SpecifiesTransposeTransfer<T>;
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
concept IsWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
concept IsFwdWmmaAlgorithm = ConvAlgorithmDescriptor<T> &&
|
||||
SpecifiesDataTransfer<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
concept IsDlAlgorithm =
|
||||
concept IsFwdDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
@@ -139,19 +149,19 @@ constexpr auto make_conv_instance()
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>)
|
||||
if constexpr(IsFwdXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>)
|
||||
else if constexpr(IsFwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>)
|
||||
else if constexpr(IsFwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>)
|
||||
else if constexpr(IsFwdDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
@@ -177,10 +187,17 @@ constexpr auto make_conv_instance()
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward weight convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
if constexpr (IsBwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"Backward weight convolution is not yet supported. "
|
||||
"Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -24,7 +24,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -62,6 +62,7 @@ consteval auto GetElementwiseOp()
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
requires ConvDirectionIsForward<Sig>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
|
||||
@@ -72,4 +73,16 @@ struct ElementwiseOps
|
||||
using CDEElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
template <auto Sig>
|
||||
requires ConvDirectionIsBackwardWeight<Sig>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
|
||||
using InElementwiseOp = typename decltype(input_op)::Op;
|
||||
using WeiElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using OutElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -216,18 +216,31 @@ consteval auto GetAuxiliaryTensorLayouts()
|
||||
return EmptyAuxiliaryTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM> &&
|
||||
ConvDirectionIsForward<Signature>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
|
||||
using ALayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
|
||||
};
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM> &&
|
||||
ConvDirectionIsBackwardWeight<Signature>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -151,6 +151,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
|
||||
}
|
||||
|
||||
template <auto Signature>
|
||||
requires ConvDirectionIsForward<Signature>
|
||||
struct FwdConvTensorDataTypes
|
||||
{
|
||||
static constexpr auto input_types =
|
||||
@@ -176,4 +177,24 @@ struct FwdConvTensorDataTypes
|
||||
using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
|
||||
};
|
||||
|
||||
template <auto Signature>
|
||||
requires ConvDirectionIsBackwardWeight<Signature>
|
||||
struct FwdConvTensorDataTypes
|
||||
{
|
||||
static constexpr auto input_types =
|
||||
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
|
||||
static constexpr auto weight_types =
|
||||
GetTensorDataAndComputeTypes<Signature.weight.config, Signature.data_type>();
|
||||
static constexpr auto output_types =
|
||||
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
|
||||
|
||||
using InDataType = typename decltype(input_types.first)::type;
|
||||
using InComputeType = typename decltype(input_types.second)::type;
|
||||
using WeiDataType = typename decltype(weight_types.first)::type;
|
||||
using WeiComputeType = typename decltype(weight_types.second)::type;
|
||||
using AccDataType =
|
||||
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
|
||||
Signature.data_type>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -149,12 +149,27 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -192,8 +192,8 @@ enum class TileConvSpecialization
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
// Enums for the convolution specializations.
|
||||
enum class ConvSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
|
||||
@@ -20,7 +20,7 @@ constexpr auto SIGNATURE =
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::FwdTransfer_4x64x1)
|
||||
@@ -34,7 +34,7 @@ using Instance = Builder::Instance;
|
||||
TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"Intrawave",
|
||||
|
||||
@@ -28,18 +28,30 @@ struct ThreadBlock
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseXdlGemm
|
||||
struct XdlParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<XdlParams>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseFwdXdlGemm : public XdlParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
};
|
||||
static_assert(ckb::SpecifiesGridwiseFwdXdlGemm<GridwiseFwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdXdlGemm : public XdlParams
|
||||
{
|
||||
size_t k0_per_block = 0;
|
||||
size_t k1 = 0;
|
||||
};
|
||||
static_assert(ckb::SpecifiesGridwiseBwdXdlGemm<GridwiseFwdXdlGemm>);
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
@@ -169,9 +181,14 @@ struct ThreadBlock_
|
||||
ThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct XdlGemm_
|
||||
struct FwdXdlGemm_
|
||||
{
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
GridwiseFwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BwdXdlGemm_
|
||||
{
|
||||
GridwiseBwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
@@ -184,12 +201,17 @@ struct Transfer_
|
||||
TransferABC transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
struct ConvSpecializationFwd_
|
||||
{
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
ConvSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
};
|
||||
|
||||
struct ConvSpecializationBwdWeight_
|
||||
{
|
||||
ConvSpecialization bwd_specialization;
|
||||
};
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
@@ -197,6 +219,12 @@ struct Prefetch_
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct TransposeParams_
|
||||
{
|
||||
size_t max_transpose_transfer_src_scalar_per_vector{1};
|
||||
size_t max_transpose_transfer_dst_scalar_per_vector{1};
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
@@ -329,7 +357,11 @@ struct ConvAlgorithmTemplate : Components...
|
||||
constexpr auto with_gemm_config(const GemmConfig& gemm) const
|
||||
{
|
||||
auto result = *this;
|
||||
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
|
||||
if constexpr(std::is_base_of_v<FwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
if constexpr(std::is_base_of_v<BwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
@@ -359,6 +391,14 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_specializations(ConvBwdWeightSpecialization bwd_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.bwd_specialization = bwd_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
|
||||
size_t groups_to_merge,
|
||||
PipelineScheduler scheduler) const
|
||||
@@ -371,6 +411,16 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_transpose_params(bool max_src_scalar_per_vector,
|
||||
bool max_dst_scalar_per_vector) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector;
|
||||
result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
@@ -456,16 +506,17 @@ struct ConvAlgorithmTemplate : Components...
|
||||
// Algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_, ConvSpecializationFwd_, BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecializationFwd_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
ConvSpecializationFwd_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlTransfer_>;
|
||||
@@ -479,4 +530,7 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
|
||||
TileConvSpecialization_,
|
||||
TileOptimizations_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user