[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)

[CK_Builder] added bwd data kernels to builder factory
 (#4582)

This PR adds bwd data WMMA and XDL kernels to the CK builder, along with
their instances, conv traits, and tests for the above.
This commit is contained in:
kabrahamAMD
2026-02-27 03:06:29 +00:00
committed by assistant-librarian[bot]
parent c8a8449eec
commit 5e06874aae
34 changed files with 2511 additions and 104 deletions

View File

@@ -47,11 +47,17 @@ concept BlockGemmPipelineDescriptor = requires(T t) {
// Concept for parameters that describe a gridwise WMMA GEMM problem.
template <typename T>
concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.k1 } -> SizeType;
{ t.m_per_wmma } -> SizeType;
{ t.n_per_wmma } -> SizeType;
{ t.m_wmma_per_wave } -> SizeType;
{ t.n_wmma_per_wave } -> SizeType;
(
requires { { T::k1 } -> SizeType; } ||
(requires { { T::ak1 } -> SizeType; } &&
requires { { T::bk1 } -> SizeType; })
) &&
requires {
{ T::m_per_wmma } -> SizeType;
{ T::n_per_wmma } -> SizeType;
{ T::m_wmma_per_wave } -> SizeType;
{ T::n_wmma_per_wave } -> SizeType;
};
};
// Concept for vectorized data transfer for convolution input tensors.
@@ -187,6 +193,14 @@ concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept for parameters that describe a gridwise bwd-data XDL GEMM problem:
// separate K1 sizes for the A and B operands plus the common XDL parameters.
template <typename T>
concept GridwiseBwdDataXdlGemmDescriptor = requires(T t) {
    { t.ak1 } -> SizeType;
    { t.bk1 } -> SizeType;
    { t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
@@ -199,6 +213,12 @@ concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise bwd-data XDL GEMM info
// (a gridwise_gemm member satisfying GridwiseBwdDataXdlGemmDescriptor).
template <typename T>
concept SpecifiesGridwiseBwdDataXdlGemm = requires(T t) {
    { t.gridwise_gemm } -> GridwiseBwdDataXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
@@ -292,6 +312,11 @@ concept SpecifiesBwdWeightConvSpecialization = requires {
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
};
// Concept to check if a struct specifies a backward-data convolution
// specialization (a bwd_data_specialization member convertible to
// ConvSpecialization).
template <typename T>
concept SpecifiesBwdDataConvSpecialization = requires {
    { T::bwd_data_specialization } -> std::convertible_to<ConvSpecialization>;
};
template <typename T>
concept SpecifiesGemmSpecialization = requires {
{ T::gemm_specialization } -> std::convertible_to<GemmSpecialization>;

View File

@@ -28,24 +28,29 @@ concept FwdXdlAlgorithmBase =
template <typename T>
concept BwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
(SpecifiesTileTransferParameters4D<T> || SpecifiesTileTransferParameters3D<T>) &&
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
template <typename T>
concept BwdWmmaAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
SpecifiesGridwiseWmmaGemm<T> &&
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
template <typename T>
concept BwdWmmaV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesGridwiseWmmaGemm<T> &&
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
SpecifiesBlockGemm<T>;
// Reference algorithm concept
@@ -107,6 +112,9 @@ concept BwdWmmaAlgorithm =
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
// Multiple-D WMMA bwd algorithm: base WMMA requirements plus multiple-D support.
template <typename T>
concept BwdMultiDWmmaAlgorithm = BwdWmmaAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
// Multiple-D WMMA v3 bwd algorithm: v3 base requirements plus multiple-D support.
template <typename T>
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;

View File

@@ -0,0 +1,115 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle_v3 instance
// of a grouped bwd Data convolution kernel.
// Translates the compile-time convolution SIGNATURE (problem description)
// and ALGORITHM (tuning parameters) into the CK device kernel's template
// argument list.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
requires ConvDirectionIsBackwardData<SIGNATURE>
struct ConvBwdDataMultiDWmmaV3Factory
{
    // Number of spatial dimensions, taken from the signature.
    static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
    // Bundles derived from the signature: tensor layouts, data types, and
    // elementwise operators for the output/weight/input (and Ds) tensors.
    using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
    using Types = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);
    // CK backward-data specialization mapped from ALGORITHM.bwd_data_specialization.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdDataConvSpecialization<ALGORITHM>();
    // NOTE(review): LOOP_SCHEDULER is not referenced by Instance below (the
    // scheduler is taken from BLOCK_GEMM) -- confirm whether it can be removed.
    static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
    // Thread-block size and per-block M/N/K tile extents.
    static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
    // Block-transfer descriptors for the A and B input tiles and the C output tile.
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
    // Block-level GEMM configuration (scheduler and pipeline version).
    static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
    // The backward convolution kernel class instance.
    // Argument order must match the CK device kernel's template declaration.
    using Instance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
            SPATIAL_DIM,
            typename Layouts::OutLayout,
            typename Layouts::WeiLayout,
            typename Layouts::DsLayout,
            typename Layouts::InLayout,
            typename Types::OutDataType,
            typename Types::WeiDataType,
            typename Types::AccDataType,
            typename Types::OutComputeType,
            typename Types::DsDataType,
            typename Types::InDataType,
            typename Ops::OutElementwiseOp,
            typename Ops::WeiElementwiseOp,
            typename Ops::InElementwiseOp,
            BWD_CONV_SPECIALIZATION,
            ALGORITHM.DoPadGemmM,
            ALGORITHM.DoPadGemmN,
            BLOCK.block_size,
            BLOCK.per_block.m,
            BLOCK.per_block.n,
            BLOCK.per_block.k,
            GRIDWISE_GEMM.ak1,
            GRIDWISE_GEMM.bk1,
            GRIDWISE_GEMM.m_per_wmma,
            GRIDWISE_GEMM.n_per_wmma,
            GRIDWISE_GEMM.m_wmma_per_wave,
            GRIDWISE_GEMM.n_wmma_per_wave,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
            A_BLOCK_TRANSFER.src_vector_dim,
            A_BLOCK_TRANSFER.src_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_padding,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
            B_BLOCK_TRANSFER.src_vector_dim,
            B_BLOCK_TRANSFER.src_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_padding,
            C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
            C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
            to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
            // The single configured GMEM-write vector width is replicated for
            // each element of the CDE shuffle transfer sequence.
            ck::Sequence<C_BLOCK_TRANSFER.scalar_per_vector,
                         C_BLOCK_TRANSFER.scalar_per_vector,
                         C_BLOCK_TRANSFER.scalar_per_vector>,
            BLOCK_GEMM.scheduler,
            BLOCK_GEMM.pipeline_version,
            typename Types::OutComputeType,
            typename Types::InComputeType,
            ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
            ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,107 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle instance
// of a grouped bwd Data convolution kernel.
// Translates the compile-time convolution SIGNATURE (problem description)
// and ALGORITHM (tuning parameters) into the CK device kernel's template
// argument list.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
requires ConvDirectionIsBackwardData<SIGNATURE>
struct ConvBwdDataMultiDWmmaFactory
{
    // Number of spatial dimensions, taken from the signature.
    static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
    // Bundles derived from the signature: tensor layouts, data types, and
    // elementwise operators for the output/weight/input (and Ds) tensors.
    using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
    using Types = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);
    // CK backward-data specialization mapped from ALGORITHM.bwd_data_specialization.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdDataConvSpecialization<ALGORITHM>();
    static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
    // Thread-block size and per-block M/N/K tile extents.
    static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
    static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
        internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
    // Block-transfer descriptors for the A and B input tiles and the C output tile.
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
    // The backward convolution kernel class instance.
    // Argument order must match the CK device kernel's template declaration.
    // Unlike the v3 variant, this kernel takes a single shared K1
    // (GRIDWISE_GEMM.k1) rather than separate ak1/bk1 values.
    using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
        SPATIAL_DIM,
        typename Layouts::OutLayout,
        typename Layouts::WeiLayout,
        typename Layouts::DsLayout,
        typename Layouts::InLayout,
        typename Types::OutDataType,
        typename Types::WeiDataType,
        typename Types::AccDataType,
        typename Types::OutComputeType,
        typename Types::DsDataType,
        typename Types::InDataType,
        typename Ops::OutElementwiseOp,
        typename Ops::WeiElementwiseOp,
        typename Ops::InElementwiseOp,
        BWD_CONV_SPECIALIZATION,
        BLOCK.block_size,
        BLOCK.per_block.m,
        BLOCK.per_block.n,
        BLOCK.per_block.k,
        GRIDWISE_GEMM.k1,
        GRIDWISE_GEMM.m_per_wmma,
        GRIDWISE_GEMM.n_per_wmma,
        GRIDWISE_GEMM.m_wmma_per_wave,
        GRIDWISE_GEMM.n_wmma_per_wave,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
        A_BLOCK_TRANSFER.src_vector_dim,
        A_BLOCK_TRANSFER.src_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_padding,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
        B_BLOCK_TRANSFER.src_vector_dim,
        B_BLOCK_TRANSFER.src_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_padding,
        C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
        C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
        to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
        C_BLOCK_TRANSFER.scalar_per_vector,
        ALGORITHM.num_gemm_k_prefetch_stages,
        LOOP_SCHEDULER,
        GRIDWISE_GEMM_PIPELINE_VERSION>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,113 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_V1 instance
// of a grouped bwd Data convolution kernel.
// Translates the compile-time convolution SIGNATURE (problem description)
// and ALGORITHM (tuning parameters) into the CK device kernel's template
// argument list.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
requires ConvDirectionIsBackwardData<SIGNATURE>
struct ConvBwdDataMultiDXdlFactory
{
    // Number of spatial dimensions, taken from the signature.
    static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
    // Bundles derived from the signature: tensor layouts, data types, and
    // elementwise operators for the output/weight/input (and Ds) tensors.
    using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
    using Types = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);
    // CK backward-data specialization mapped from ALGORITHM.bwd_data_specialization.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdDataConvSpecialization<ALGORITHM>();
    static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
    // Thread-block size and per-block M/N/K tile extents.
    static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
    // XDL (MFMA) instruction shape and per-wave repeat counts.
    static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
    // Block-transfer descriptors for the A and B input tiles and the C output tile.
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
    // The backward convolution kernel class instance.
    // Argument order must match the CK device kernel's template declaration.
    using Instance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
            SPATIAL_DIM,
            typename Layouts::OutLayout,
            typename Layouts::WeiLayout,
            typename Layouts::DsLayout,
            typename Layouts::InLayout,
            typename Types::OutDataType,
            typename Types::WeiDataType,
            typename Types::AccDataType,
            typename Types::OutComputeType,
            typename Types::DsDataType,
            typename Types::InDataType,
            typename Ops::OutElementwiseOp,
            typename Ops::WeiElementwiseOp,
            typename Ops::InElementwiseOp,
            BWD_CONV_SPECIALIZATION,
            ALGORITHM.DoPadGemmM,
            ALGORITHM.DoPadGemmN,
            ALGORITHM.num_gemm_k_prefetch_stages,
            BLOCK.block_size,
            BLOCK.per_block.m,
            BLOCK.per_block.n,
            BLOCK.per_block.k,
            GRIDWISE_GEMM.ak1,
            GRIDWISE_GEMM.bk1,
            XDL_PARAMS.m_per_xdl,
            XDL_PARAMS.n_per_xdl,
            XDL_PARAMS.m_xdl_per_wave,
            XDL_PARAMS.n_xdl_per_wave,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
            A_BLOCK_TRANSFER.src_vector_dim,
            A_BLOCK_TRANSFER.src_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_padding,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
            B_BLOCK_TRANSFER.src_vector_dim,
            B_BLOCK_TRANSFER.src_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_padding,
            C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
            C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
            to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
            C_BLOCK_TRANSFER.scalar_per_vector,
            LOOP_SCHEDULER,
            typename Types::OutComputeType,
            typename Types::InComputeType,
            ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
            ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -77,6 +77,9 @@
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp"
namespace ck_tile::builder::factory {
@@ -148,13 +151,32 @@ constexpr auto make_conv_instance()
"WMMA, DL (NHWC layout), or Large Tensor variant.");
}
}
// Backward data direction (will expand with more algorithms in the future)
// Backward data direction
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
static_assert(false,
"Backward data convolution: Only reference and tile algorithms supported "
"currently. "
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
{
return typename ConvBwdDataMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
{
return
typename ConvBwdDataMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDWmmaAlgorithm<AlgoType>)
{
return typename ConvBwdDataMultiDWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else
{
static_assert(
false,
"No suitable backward data convolution kernel factory found for the provided "
"ALGORITHM. "
"The ALGORITHM must satisfy requirements for one of: Reference, XDL multiple d, "
"Wmma multiple d, "
"or WMMA multiple d v3.");
}
}
// Backward weight direction (will expand with more algorithms in the future)
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)

View File

@@ -5,6 +5,7 @@
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -186,4 +187,24 @@ SetBwdWeightConvSpecialization()
}
}
// Maps the builder-level ConvSpecialization taken from
// ALGORITHM.bwd_data_specialization onto CK's
// ConvolutionBackwardDataSpecialization enum.
// This is a consteval function: a `throw` reached during constant evaluation
// is ill-formed, so an unsupported specialization becomes a compile-time
// error whose diagnostic carries the message text below.
// Fix: diagnostic strings now spell the enumerators exactly as declared
// (previously "FILTER_1x1_PAD0" and "FILTER ODD_C").
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
SetBwdDataConvSpecialization()
{
    constexpr auto specialization = ALGORITHM.bwd_data_specialization;
    using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
    switch(specialization)
    {
    case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
    case ConvSpecialization::FILTER_1X1_PAD0:
        throw "FILTER_1X1_PAD0 is not supported for backward data convolution.";
    case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
    case ConvSpecialization::ODD_C:
        throw "ODD_C is not supported for backward data convolution.";
    case ConvSpecialization::FILTER_3x3:
        throw "FILTER_3x3 is not supported for backward data convolution.";
    default: throw "Unsupported ConvSpecialization";
    }
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -90,6 +90,10 @@ class ConvDescription : public Description
2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
else
f.writeLine(2, "Struct does not contain optional gemm_padding argument");
if(traits_.do_pad_gemm_m)
f.writeLine(2, "Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false));
if(traits_.do_pad_gemm_n)
f.writeLine(2, "Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false));
f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
@@ -103,7 +107,7 @@ class ConvDescription : public Description
traits_.warp_gemm.n_iter);
// Memory Access section
f.writeLast(2, "Memory access:");
f.writeLine(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
@@ -196,7 +200,7 @@ class ConvDescription : public Description
traits_.c_tile_transfer.thread_cluster_dims[2],
"×",
traits_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLine(4,
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
traits_.c_tile_transfer.scalar_per_vector);
if(traits_.num_gemm_k_prefetch_stage)
@@ -215,14 +219,14 @@ class ConvDescription : public Description
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
if(traits_.max_transpose_dst_scalar_per_vector)
if(traits_.max_transpose_transfer_dst_scalar_per_vector)
f.writeLine(2,
"Max Transpose dst scalar per vector: ",
traits_.max_transpose_dst_scalar_per_vector.value_or(0));
traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0));
else
f.writeLine(
2,
"Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_dst_scalar_per_vector parameter");
if(traits_.num_groups_to_merge)
f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
else

View File

@@ -108,8 +108,10 @@ struct ConvTraits
builder::PipelineScheduler pipeline_scheduler;
std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
std::optional<int> max_transpose_transfer_dst_scalar_per_vector = std::nullopt;
std::optional<int> num_groups_to_merge = std::nullopt;
std::optional<bool> do_pad_gemm_m = std::nullopt;
std::optional<bool> do_pad_gemm_n = std::nullopt;
};
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for
/// DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag instances.
/// Extracts the kernel's compile-time configuration into a ConvTraits value.
template <typename Instance>
requires HasInstanceTraits<Instance> &&
         std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                      DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        // NOTE(review): uses the bwd-*weight* layout helper for a bwd-*data*
        // instance; looks copy-pasted -- confirm it yields the correct layout.
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
        // This (non-v3) WMMA kernel uses a single shared K1 for A and B.
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(
            InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock),
        .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for
/// DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag instances.
/// Extracts the kernel's compile-time configuration into a ConvTraits value.
template <typename Instance>
requires HasInstanceTraits<Instance> &&
         std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                      DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        // NOTE(review): uses the bwd-*weight* layout helper for a bwd-*data*
        // instance; looks copy-pasted -- confirm it yields the correct layout.
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
        // The v3 kernel carries separate K1 values for the A and B operands.
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1, InstTraits::kK0PerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1, InstTraits::kK0PerBlock),
        .warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
        // The CDE scalar-per-vector is a sequence here; the factory replicates
        // one value for every element, so reporting element [0] is sufficient.
        .c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(
            InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]),
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
        .max_transpose_transfer_src_scalar_per_vector =
            InstTraits::kMaxTransposeTransferSrcScalarPerVector,
        .max_transpose_transfer_dst_scalar_per_vector =
            InstTraits::kMaxTransposeTransferDstScalarPerVector,
        .do_pad_gemm_m = InstTraits::kDoPadGemmM,
        .do_pad_gemm_n = InstTraits::kDoPadGemmN,
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,60 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for
/// DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag instances.
/// Extracts the kernel's compile-time configuration into a ConvTraits value.
template <typename Instance>
requires HasInstanceTraits<Instance> &&
         std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
                      DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using InstTraits = InstanceTraits<Instance>;
    return ConvTraits{
        .spatial_dim = InstTraits::kSpatialDim,
        .direction = conv_direction<Instance>(),
        // NOTE(review): uses the bwd-*weight* layout helper for a bwd-*data*
        // instance; looks copy-pasted -- confirm it yields the correct layout.
        .layout = bwd_wei_conv_layout<Instance>(),
        .data_type = conv_data_type<typename InstTraits::InDataType>(),
        .input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
        .output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size = InstTraits::kBlockSize,
        .tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
        // This kernel carries separate K1 values for the A and B operands.
        .a_tile_transfer =
            conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1, InstTraits::kK0PerBlock),
        .b_tile_transfer =
            conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1, InstTraits::kK0PerBlock),
        .warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
        // Built inline rather than via a helper: the XDL kernel's C transfer
        // exposes a 4-D thread cluster and a single scalar-per-vector value.
        .c_tile_transfer =
            {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
                                .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
             .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
                                     InstTraits::kCThreadClusterLengths[1],
                                     InstTraits::kCThreadClusterLengths[2],
                                     InstTraits::kCThreadClusterLengths[3]},
             .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
        .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
        .pipeline_version = get_pipeline_version<InstTraits>(),
        .pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
        .max_transpose_transfer_src_scalar_per_vector =
            InstTraits::kMaxTransposeTransferSrcScalarPerVector,
        .max_transpose_transfer_dst_scalar_per_vector =
            InstTraits::kMaxTransposeTransferDstScalarPerVector,
        .do_pad_gemm_m = InstTraits::kDoPadGemmM,
        .do_pad_gemm_n = InstTraits::kDoPadGemmN,
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -42,8 +42,9 @@ constexpr ConvTraits instance_to_conv_traits()
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
.max_transpose_transfer_dst_scalar_per_vector =
InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}

View File

@@ -49,8 +49,9 @@ constexpr ConvTraits instance_to_conv_traits()
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
.max_transpose_transfer_dst_scalar_per_vector =
InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}

View File

@@ -42,7 +42,8 @@ constexpr ConvTraits instance_to_conv_traits()
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
.max_transpose_transfer_dst_scalar_per_vector =
InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}

View File

@@ -49,7 +49,8 @@ constexpr ConvTraits instance_to_conv_traits()
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
.max_transpose_transfer_dst_scalar_per_vector =
InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}

View File

@@ -796,7 +796,8 @@ constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer(
ck::index_t CDEBlockTansferScalarPerVector = InstTraits::kCDEBlockTransferScalarPerVector)
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
@@ -805,7 +806,7 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
.scalar_per_vector = CDEBlockTansferScalarPerVector};
}
template <typename InstTraits>

View File

@@ -18,3 +18,8 @@
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
// Bwd data instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"

View File

@@ -0,0 +1,315 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
// Forward declaration of the CK WMMA backward-data (multiple-D) device kernel.
// Only the template signature is needed here so the InstanceTraits partial
// specialization in this file can match concrete instances without pulling in
// the full CK device-operation header.
// NOTE(review): parameter order, kinds, and types must stay in exact sync with
// the CK header — any drift silently prevents the specialization from matching.
namespace ck::tensor_operation::device {

template <index_t NDimSpatial,
          typename OutLayout, // output image
          typename WeiLayout, // weight
          typename DsLayout,  // bias
          typename InLayout,  // input image
          typename OutDataType, // output image
          typename WeiDataType, // weight
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType, // bias
          typename InDataType, // input image
          typename OutElementwiseOp, // output image
          typename WeiElementwiseOp, // weight
          typename InElementwiseOp,  // C, bias, and input image
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          // Single K1 shared by both the A and B sides (the V3 kernel splits
          // this into separate AK1/BK1 parameters).
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage,
          ck::LoopScheduler LoopSched,
          ck::PipelineVersion PipelineVer>
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle;

} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle device kernel
struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag;
template <index_t NDimSpatial,
typename OutLayout_, // output image
typename WeiLayout_, // weight
typename DsLayout_, // bias
typename InLayout_, // input image
typename OutDataType_, // output image
typename WeiDataType_, // weight
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_, // bias
typename InDataType_, // input image
typename OutElementwiseOp_, // output image
typename WeiElementwiseOp_, // weight
typename InElementwiseOp_, // C, bias, and input image
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1_,
typename ABlockTransferThreadClusterArrangeOrder_,
typename ABlockTransferSrcAccessOrder_,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1_,
typename BBlockTransferThreadClusterArrangeOrder_,
typename BBlockTransferSrcAccessOrder_,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage,
ck::LoopScheduler LoopSched,
ck::PipelineVersion PipelineVer>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
NDimSpatial,
OutLayout_, // output image
WeiLayout_, // weight
DsLayout_, // bias
InLayout_, // input image
OutDataType_, // output image
WeiDataType_, // weight
AccDataType_,
CShuffleDataType_,
DsDataType_, // bias
InDataType_, // input image
OutElementwiseOp_, // output image
WeiElementwiseOp_, // weight
InElementwiseOp_, // C, bias, and input image
ConvBackwardDataSpecialization,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerWMMA,
NPerWMMA,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1_,
ABlockTransferThreadClusterArrangeOrder_,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1_,
BBlockTransferThreadClusterArrangeOrder_,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
using OutLayout = OutLayout_;
using DsLayout = DsLayout_;
using InDataType = InDataType_;
using WeiDataType = WeiDataType_;
using OutDataType = OutDataType_;
using AccDataType = AccDataType_;
using DsDataType = DsDataType_;
using InElementwiseOperation = InElementwiseOp_;
using WeiElementwiseOperation = WeiElementwiseOp_;
using OutElementwiseOperation = OutElementwiseOp_;
static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerWmma = MPerWMMA;
static constexpr ck::index_t kNPerWmma = NPerWMMA;
static constexpr ck::index_t kMRepeat = MRepeat;
static constexpr ck::index_t kNRepeat = NRepeat;
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
static constexpr ck::index_t kCDEShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
using ABlockTransferThreadClusterLengths_AK0_M_AK1 =
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;
using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << kTensorOpName;
// Template parameters in exact order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<InLayout>(); // 5. InLayout
oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
oss << "," << detail::type_name<InDataType>(); // 10. InDataType
oss << ","
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
// OutElementwiseOperation
oss << ","
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
// WeiElementwiseOperation
oss << ","
<< detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
oss << ","
<< detail::conv_bwd_data_spec_name(
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kK1; // 19. ABK1
oss << "," << kMPerWmma; // 20. MPerWmma
oss << "," << kNPerWmma; // 21. NPerWmma
oss << "," << kMRepeat; // 22. MRepeat
oss << "," << kNRepeat; // 23. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
oss << ","
<< detail::sequence_name<
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock; // 41.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 42. LoopSched
oss << "," << kNumGemmKPrefetchStage; // 43.
oss << "," << detail::pipeline_version_name(kPipelineVer); // 44.
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,350 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
// Forward declaration of the CK WMMA backward-data (multiple-D) V3 device kernel.
// Only the template signature is needed here so the InstanceTraits partial
// specialization in this file can match concrete instances without pulling in
// the full CK device-operation header.
// NOTE(review): parameter order, kinds, and types must stay in exact sync with
// the CK header — any drift silently prevents the specialization from matching.
namespace ck::tensor_operation::device {

template <index_t NDimSpatial,
          typename OutLayout, // output image
          typename WeiLayout, // weight
          typename DsLayout,  // bias
          typename InLayout,  // input image
          typename OutDataType, // output image
          typename WeiDataType, // weight
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType, // bias
          typename InDataType, // input image
          typename OutElementwiseOp, // output image
          typename WeiElementwiseOp, // weight
          typename InElementwiseOp,  // C, bias, and input image
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          bool DoPadGemmM,
          bool DoPadGemmN,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          // Unlike the non-V3 kernel (single K1), A and B sides carry separate
          // K1 vector lengths here.
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          // A type (sequence) here, not a scalar as in the non-V3 kernel.
          typename CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
          ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA,
          typename ComputeTypeB,
          ck::index_t max_transpose_transfer_src_scalar_per_vector,
          ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3;

} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {

/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 device kernel
struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag;

/// @brief InstanceTraits specialization for the WMMA bwd-data (multiple-D) V3 kernel.
///
/// Re-exposes every template argument of the device kernel as a named
/// compile-time member so the builder can introspect an instance without
/// instantiating it. instance_string() renders the arguments back in the order
/// of the kernel's template parameter list.
template <index_t NDimSpatial,
          typename OutLayout_, // output image
          typename WeiLayout_, // weight
          typename DsLayout_,  // bias
          typename InLayout_,  // input image
          typename OutDataType_, // output image
          typename WeiDataType_, // weight
          typename AccDataType_,
          typename CShuffleDataType_,
          typename DsDataType_, // bias
          typename InDataType_, // input image
          typename OutElementwiseOp_, // output image
          typename WeiElementwiseOp_, // weight
          typename InElementwiseOp_,  // C, bias, and input image
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          bool DoPadGemmM,
          bool DoPadGemmN,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1_,
          typename ABlockTransferThreadClusterArrangeOrder_,
          typename ABlockTransferSrcAccessOrder_,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1_,
          typename BBlockTransferThreadClusterArrangeOrder_,
          typename BBlockTransferSrcAccessOrder_,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
          typename CDEShuffleBlockTransferScalarPerVector_NPerBlock_,
          ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
          ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA_,
          typename ComputeTypeB_,
          ck::index_t max_transpose_transfer_src_scalar_per_vector,
          ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct InstanceTraits<
    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
        NDimSpatial,
        OutLayout_,  // output image
        WeiLayout_,  // weight
        DsLayout_,   // bias
        InLayout_,   // input image
        OutDataType_, // output image
        WeiDataType_, // weight
        AccDataType_,
        CShuffleDataType_,
        DsDataType_, // bias
        InDataType_, // input image
        OutElementwiseOp_, // output image
        WeiElementwiseOp_, // weight
        InElementwiseOp_,  // C, bias, and input image
        ConvBackwardDataSpecialization,
        DoPadGemmM,
        DoPadGemmN,
        BlockSize,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        AK1,
        BK1,
        MPerWMMA,
        NPerWMMA,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1_,
        ABlockTransferThreadClusterArrangeOrder_,
        ABlockTransferSrcAccessOrder_,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1_,
        BBlockTransferThreadClusterArrangeOrder_,
        BBlockTransferSrcAccessOrder_,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
        CDEShuffleBlockTransferScalarPerVector_NPerBlock_,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA_,
        ComputeTypeB_,
        max_transpose_transfer_src_scalar_per_vector,
        max_transpose_transfer_dst_scalar_per_vector>>
{
    static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3";

    /// @brief Tag type identifying this device kernel variant
    using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag;

    static constexpr ck::index_t kSpatialDim = NDimSpatial;

    // Layouts, data types, and elementwise ops re-exported under canonical names.
    using InLayout                = InLayout_;
    using WeiLayout               = WeiLayout_;
    using OutLayout               = OutLayout_;
    using DsLayout                = DsLayout_;
    using InDataType              = InDataType_;
    using WeiDataType             = WeiDataType_;
    using OutDataType             = OutDataType_;
    using AccDataType             = AccDataType_;
    using DsDataType              = DsDataType_;
    using InElementwiseOperation  = InElementwiseOp_;
    using WeiElementwiseOperation = WeiElementwiseOp_;
    using OutElementwiseOperation = OutElementwiseOp_;

    static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;

    // Block-level tile shape; V3 carries distinct AK1/BK1 vector lengths.
    static constexpr ck::index_t kBlockSize  = BlockSize;
    static constexpr ck::index_t kMPerBlock  = MPerBlock;
    static constexpr ck::index_t kNPerBlock  = NPerBlock;
    static constexpr ck::index_t kK0PerBlock = K0PerBlock;
    static constexpr ck::index_t kAK1        = AK1;
    static constexpr ck::index_t kBK1        = BK1;

    // WMMA warp-tile shape and per-wave repeats.
    static constexpr ck::index_t kMPerWmma = MPerWMMA;
    static constexpr ck::index_t kNPerWmma = NPerWMMA;
    static constexpr ck::index_t kMRepeat  = MRepeat;
    static constexpr ck::index_t kNRepeat  = NRepeat;

    // C-shuffle epilogue configuration.
    static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
    static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;

    // Transpose-transfer vector-length limits (V3-specific knobs).
    static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
        max_transpose_transfer_src_scalar_per_vector;
    static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
        max_transpose_transfer_dst_scalar_per_vector;

    static constexpr bool kDoPadGemmM = DoPadGemmM;
    static constexpr bool kDoPadGemmN = DoPadGemmN;

    // In V3 the CDE scalar-per-vector is a sequence type (one entry per D tensor),
    // not a single index as in the non-V3 kernel.
    using CDEShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVector_NPerBlock_;
    static constexpr auto kCDEShuffleBlockTransferScalarPerVector_NPerBlock =
        detail::SequenceToArray<CDEShuffleBlockTransferScalarPerVector_NPerBlock>::value;

    static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
    static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;

    // A-side block transfer parameters.
    using ABlockTransferThreadClusterLengths_AK0_M_AK1 =
        ABlockTransferThreadClusterLengths_AK0_M_AK1_;
    using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
    using ABlockTransferSrcAccessOrder            = ABlockTransferSrcAccessOrder_;
    // A block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kAThreadClusterLengths =
        detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
    static constexpr auto kAThreadClusterArrangeOrder =
        detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kABlockTransferSrcAccessOrder =
        detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
    static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
        ABlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
        ABlockTransferDstScalarPerVector_AK1;
    static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;

    // B-side block transfer parameters.
    using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
        BBlockTransferThreadClusterLengths_BK0_N_BK1_;
    using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
    using BBlockTransferSrcAccessOrder            = BBlockTransferSrcAccessOrder_;
    // B block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kBThreadClusterLengths =
        detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
    static constexpr auto kBThreadClusterArrangeOrder =
        detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kBBlockTransferSrcAccessOrder =
        detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
    static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
        BBlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
        BBlockTransferDstScalarPerVector_BK1;
    static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;

    using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
    // CDE shuffle thread cluster dimensions (converted to std::array)
    static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;

    using ComputeTypeA = ComputeTypeA_;
    using ComputeTypeB = ComputeTypeB_;

    /// @brief Render the kernel's template arguments as an instance string,
    /// following the template parameter order declared above.
    static std::string instance_string()
    {
        std::ostringstream oss;
        // Kernel type name
        oss << kTensorOpName;
        // Template arguments, in declaration order.
        oss << "<" << kSpatialDim;                      // NDimSpatial
        oss << "," << detail::layout_name<OutLayout>(); // OutLayout
        oss << "," << detail::layout_name<WeiLayout>(); // WeiLayout
        oss << "," << detail::tuple_name<DsLayout>();   // DsLayout
        oss << "," << detail::layout_name<InLayout>();  // InLayout
        oss << "," << detail::type_name<OutDataType>(); // OutDataType
        oss << "," << detail::type_name<WeiDataType>(); // WeiDataType
        oss << "," << detail::type_name<AccDataType>(); // AccDataType
        // NOTE(review): CShuffleDataType is a template parameter but is not
        // emitted; confirm the canonical instance-name format really omits it.
        oss << "," << detail::tuple_name<DsDataType>(); // DsDataType
        oss << "," << detail::type_name<InDataType>();  // InDataType
        oss << "," << detail::elementwise_op_name<OutElementwiseOperation>();
        oss << "," << detail::elementwise_op_name<WeiElementwiseOperation>();
        oss << "," << detail::elementwise_op_name<InElementwiseOperation>();
        oss << "," << detail::conv_bwd_data_spec_name(kConvBwdDataSpecialization);
        // Render bools as "true"/"false" for consistency with the other boolean
        // template arguments below (streaming the raw bool yields "1"/"0").
        oss << "," << (kDoPadGemmM ? "true" : "false"); // DoPadGemmM
        oss << "," << (kDoPadGemmN ? "true" : "false"); // DoPadGemmN
        oss << "," << kBlockSize;  // BlockSize
        oss << "," << kMPerBlock;  // MPerBlock
        oss << "," << kNPerBlock;  // NPerBlock
        oss << "," << kK0PerBlock; // K0PerBlock
        oss << "," << kAK1;        // AK1
        oss << "," << kBK1;        // BK1
        oss << "," << kMPerWmma;   // MPerWmma
        oss << "," << kNPerWmma;   // NPerWmma
        oss << "," << kMRepeat;    // MRepeat
        oss << "," << kNRepeat;    // NRepeat
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>();
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>();
        oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>();
        oss << "," << kABlockTransferSrcVectorDim;
        oss << "," << kABlockTransferSrcScalarPerVector;
        oss << "," << kABlockTransferDstScalarPerVectorK1;
        oss << "," << (kABlockLdsExtraM ? "true" : "false");
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>();
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>();
        oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>();
        oss << "," << kBBlockTransferSrcVectorDim;
        oss << "," << kBBlockTransferSrcScalarPerVector;
        oss << "," << kBBlockTransferDstScalarPerVectorK1;
        oss << "," << (kBBlockLdsExtraN ? "true" : "false");
        oss << "," << kCShuffleMRepeatPerShuffle;
        oss << "," << kCShuffleNRepeatPerShuffle;
        oss << ","
            << detail::sequence_name<
                   CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>();
        // NOTE(review): only element 0 of the scalar-per-vector sequence is
        // emitted; confirm multi-D instances never need the remaining elements.
        oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0];
        oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched);
        oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer);
        oss << "," << detail::type_name<ComputeTypeA>();
        oss << "," << detail::type_name<ComputeTypeB>();
        oss << "," << kMaxTransposeTransferSrcScalarPerVector;
        oss << "," << kMaxTransposeTransferDstScalarPerVector;
        oss << ">";
        return oss.str();
    }
};

} // namespace ck_tile::reflect

View File

@@ -0,0 +1,345 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
// Forward declaration of the CK XDL backward-data (multiple-D) device kernel.
// Only the template signature is needed here so the InstanceTraits partial
// specialization in this file can match concrete instances without pulling in
// the full CK device-operation header.
// NOTE(review): parameter order, kinds, and types must stay in exact sync with
// the CK header — any drift silently prevents the specialization from matching.
namespace ck::tensor_operation::device {

template <ck::index_t NDimSpatial,
          typename OutLayout,
          typename WeiLayout,
          typename DsLayout,
          typename InLayout,
          typename OutDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutComputeType,
          typename DsDataType,
          typename InDataType,
          typename OutElementwiseOperation,
          typename WeiElementwiseOperation,
          typename InElementwiseOperation,
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          bool do_pad_gemm_m,
          bool do_pad_gemm_n,
          ck::index_t num_gemm_k_prefetch_stages,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          // NOTE(review): declared ck::index_t here, but the InstanceTraits
          // partial specialization in this file declares the corresponding
          // parameter as bool — a non-type parameter type mismatch can silently
          // stop the specialization from matching. Confirm against the CK header.
          ck::index_t ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          // NOTE(review): same index_t-vs-bool concern as ABlockLdsAddExtraM above.
          ck::index_t BBlockLdsAddExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
          ck::LoopScheduler LoopSched,
          typename ComputeTypeA,
          typename ComputeTypeB,
          ck::index_t max_transpose_transfer_src_scalar_per_vector,
          ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;

} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag;
template <ck::index_t NDimSpatial,
typename OutLayout_,
typename WeiLayout_,
typename DsLayout_,
typename InLayout_,
typename OutDataType_,
typename WeiDataType_,
typename AccDataType_,
typename OutComputeType_,
typename DsDataType_,
typename InDataType_,
typename OutElementwiseOperation_,
typename WeiElementwiseOperation_,
typename InElementwiseOperation_,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization,
bool do_pad_gemm_m,
bool do_pad_gemm_n,
ck::index_t num_gemm_k_prefetch_stages,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1_,
typename ABlockTransferThreadClusterArrangeOrder_,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1_,
typename BBlockTransferThreadClusterArrangeOrder_,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsAddExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
ck::LoopScheduler LoopSched,
typename ComputeTypeA_,
typename ComputeTypeB_,
ck::index_t max_transpose_transfer_src_scalar_per_vector,
ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct InstanceTraits<
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
NDimSpatial,
OutLayout_,
WeiLayout_,
DsLayout_,
InLayout_,
OutDataType_,
WeiDataType_,
AccDataType_,
OutComputeType_,
DsDataType_,
InDataType_,
OutElementwiseOperation_,
WeiElementwiseOperation_,
InElementwiseOperation_,
ConvBackwardDataSpecialization,
do_pad_gemm_m,
do_pad_gemm_n,
num_gemm_k_prefetch_stages,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1_,
ABlockTransferThreadClusterArrangeOrder_,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1_,
BBlockTransferThreadClusterArrangeOrder_,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
CBlockTransferScalarPerVector_NWaveNPerXdl,
LoopSched,
ComputeTypeA_,
ComputeTypeB_,
max_transpose_transfer_src_scalar_per_vector,
max_transpose_transfer_dst_scalar_per_vector>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle";
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
using OutLayout = OutLayout_;
using DsLayout = DsLayout_;
using InDataType = InDataType_;
using WeiDataType = WeiDataType_;
using OutDataType = OutDataType_;
using AccDataType = AccDataType_;
using DsDataType = DsDataType_;
using InElementwiseOperation = InElementwiseOperation_;
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kAK1 = AK1;
static constexpr ck::index_t kBK1 = BK1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
max_transpose_transfer_src_scalar_per_vector;
static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
max_transpose_transfer_dst_scalar_per_vector;
static constexpr bool kDoPadGemmM = do_pad_gemm_m;
static constexpr bool kDoPadGemmN = do_pad_gemm_n;
static constexpr int kNumGemmKPrefetchStage = num_gemm_k_prefetch_stages;
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << kTensorOpName;
// Template parameters in exact order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<InLayout>(); // 5. InLayout
oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
oss << "," << detail::type_name<InDataType>(); // 10. InDataType
oss << ","
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
// OutElementwiseOperation
oss << ","
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
// WeiElementwiseOperation
oss << ","
<< detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
oss << ","
<< detail::conv_bwd_data_spec_name(
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
oss << "," << kDoPadGemmM;
oss << "," << kDoPadGemmN;
oss << "," << kNumGemmKPrefetchStage;
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kAK1; // 19. AK1
oss << "," << kBK1; // 19. BK1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
oss << ","
<< detail::sequence_name<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42.
oss << "," << kNumGemmKPrefetchStage; // 41.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
oss << "," << detail::type_name<ComputeTypeA>(); // 44.
oss << "," << detail::type_name<ComputeTypeB>(); // 45.
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 46.
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 47.
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -18,6 +18,7 @@
#include <string_view>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -138,6 +139,18 @@ constexpr std::string_view conv_bwd_weight_spec_name(
}
}
/// @brief Convert a ConvolutionBackwardDataSpecialization enum value to its
/// string name for use in instance strings.
/// @param spec the backward-data specialization to name
/// @return "Default" or "Filter1x1Stride1Pad0"; "Unknown" for any other value.
constexpr std::string_view
conv_bwd_data_spec_name(ck::tensor_operation::device::ConvolutionBackwardDataSpecialization spec)
{
    using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
    switch(spec)
    {
    case Default: return "Default";
    case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
    }
    // Defensive fallback: without it, an out-of-range enum value lets control
    // flow off the end of a value-returning function (undefined behavior,
    // flagged by -Wreturn-type).
    return "Unknown";
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{