mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)
[CK_Builder] added bwd data kernels to builder factory (#4582) This PR adds bwd data wmma and xdl kernels to the ck builder, their instance and conv traits as well as tests for the above.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c8a8449eec
commit
5e06874aae
@@ -47,11 +47,17 @@ concept BlockGemmPipelineDescriptor = requires(T t) {
|
||||
// Concept for parameters that describe a gridwise WMMA GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.m_per_wmma } -> SizeType;
|
||||
{ t.n_per_wmma } -> SizeType;
|
||||
{ t.m_wmma_per_wave } -> SizeType;
|
||||
{ t.n_wmma_per_wave } -> SizeType;
|
||||
(
|
||||
requires { { T::k1 } -> SizeType; } ||
|
||||
(requires { { T::ak1 } -> SizeType; } &&
|
||||
requires { { T::bk1 } -> SizeType; })
|
||||
) &&
|
||||
requires {
|
||||
{ T::m_per_wmma } -> SizeType;
|
||||
{ T::n_per_wmma } -> SizeType;
|
||||
{ T::m_wmma_per_wave } -> SizeType;
|
||||
{ T::n_wmma_per_wave } -> SizeType;
|
||||
};
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
@@ -187,6 +193,14 @@ concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdDataXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> SizeType;
|
||||
{ t.bk1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
|
||||
@@ -199,6 +213,12 @@ concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdDataXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdDataXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise WMMA GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
|
||||
@@ -292,6 +312,11 @@ concept SpecifiesBwdWeightConvSpecialization = requires {
|
||||
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesBwdDataConvSpecialization = requires {
|
||||
{ T::bwd_data_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGemmSpecialization = requires {
|
||||
{ T::gemm_specialization } -> std::convertible_to<GemmSpecialization>;
|
||||
|
||||
@@ -28,24 +28,29 @@ concept FwdXdlAlgorithmBase =
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
(SpecifiesTileTransferParameters4D<T> || SpecifiesTileTransferParameters3D<T>) &&
|
||||
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
|
||||
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
SpecifiesGridwiseWmmaGemm<T> &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
// Reference algorithm concept
|
||||
@@ -107,6 +112,9 @@ concept BwdWmmaAlgorithm =
|
||||
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
|
||||
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDWmmaAlgorithm = BwdWmmaAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle_v3 instance
|
||||
// of a grouped bwd Data convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct ConvBwdDataMultiDWmmaV3Factory
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdDataConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The backward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Types::OutDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::InDataType,
|
||||
typename Ops::OutElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
ALGORITHM.DoPadGemmM,
|
||||
ALGORITHM.DoPadGemmN,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
ck::Sequence<C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector>,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdDataMultipleD_wmma_CShuffle instance
|
||||
// of a grouped bwd Data convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct ConvBwdDataMultiDWmmaFactory
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdDataConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The backward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Types::OutDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::InDataType,
|
||||
typename Ops::OutElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
LOOP_SCHEDULER,
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_V1 instance
|
||||
// of a grouped bwd Data convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct ConvBwdDataMultiDXdlFactory
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdDataConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The backward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Types::OutDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::InDataType,
|
||||
typename Ops::OutElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
ALGORITHM.DoPadGemmM,
|
||||
ALGORITHM.DoPadGemmN,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
LOOP_SCHEDULER,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -77,6 +77,9 @@
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_cshuffle_v3_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -148,13 +151,32 @@ constexpr auto make_conv_instance()
|
||||
"WMMA, DL (NHWC layout), or Large Tensor variant.");
|
||||
}
|
||||
}
|
||||
// Backward data direction (will expand with more algorithms in the future)
|
||||
// Backward data direction
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
static_assert(false,
|
||||
"Backward data convolution: Only reference and tile algorithms supported "
|
||||
"currently. "
|
||||
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
|
||||
if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdDataMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return
|
||||
typename ConvBwdDataMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdDataMultiDWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable backward data convolution kernel factory found for the provided "
|
||||
"ALGORITHM. "
|
||||
"The ALGORITHM must satisfy requirements for one of: Reference, XDL multiple d, "
|
||||
"Wmma multiple d, "
|
||||
"or WMMA multiple d v3.");
|
||||
}
|
||||
}
|
||||
// Backward weight direction (will expand with more algorithms in the future)
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
@@ -186,4 +187,24 @@ SetBwdWeightConvSpecialization()
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
SetBwdDataConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_data_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0:
|
||||
throw "FILTER_1x1_PAD0 is not supported for backward data convolution.";
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C:
|
||||
throw "FILTER ODD_C is not supported for backward data convolution.";
|
||||
case ConvSpecialization::FILTER_3x3:
|
||||
throw "FILTER_3x3 is not supported for backward data convolution.";
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -90,6 +90,10 @@ class ConvDescription : public Description
|
||||
2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
|
||||
else
|
||||
f.writeLine(2, "Struct does not contain optional gemm_padding argument");
|
||||
if(traits_.do_pad_gemm_m)
|
||||
f.writeLine(2, "Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false));
|
||||
if(traits_.do_pad_gemm_n)
|
||||
f.writeLine(2, "Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false));
|
||||
f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
|
||||
// Pipeline section
|
||||
f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
|
||||
@@ -103,7 +107,7 @@ class ConvDescription : public Description
|
||||
traits_.warp_gemm.n_iter);
|
||||
|
||||
// Memory Access section
|
||||
f.writeLast(2, "Memory access:");
|
||||
f.writeLine(2, "Memory access:");
|
||||
|
||||
f.writeLine(3, "A Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
@@ -196,7 +200,7 @@ class ConvDescription : public Description
|
||||
traits_.c_tile_transfer.thread_cluster_dims[2],
|
||||
"×",
|
||||
traits_.c_tile_transfer.thread_cluster_dims[3]);
|
||||
f.writeLine(4,
|
||||
f.writeLast(4,
|
||||
"Vector access (GMEM write) instruction size: ",
|
||||
traits_.c_tile_transfer.scalar_per_vector);
|
||||
if(traits_.num_gemm_k_prefetch_stage)
|
||||
@@ -215,14 +219,14 @@ class ConvDescription : public Description
|
||||
f.writeLine(2,
|
||||
"Struct does not contain optional "
|
||||
"max_transpose_transfer_src_scalar_per_vector parameter");
|
||||
if(traits_.max_transpose_dst_scalar_per_vector)
|
||||
if(traits_.max_transpose_transfer_dst_scalar_per_vector)
|
||||
f.writeLine(2,
|
||||
"Max Transpose dst scalar per vector: ",
|
||||
traits_.max_transpose_dst_scalar_per_vector.value_or(0));
|
||||
traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0));
|
||||
else
|
||||
f.writeLine(
|
||||
2,
|
||||
"Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
|
||||
f.writeLine(2,
|
||||
"Struct does not contain optional "
|
||||
"max_transpose_transfer_dst_scalar_per_vector parameter");
|
||||
if(traits_.num_groups_to_merge)
|
||||
f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
|
||||
else
|
||||
|
||||
@@ -108,8 +108,10 @@ struct ConvTraits
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
|
||||
std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
|
||||
std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
|
||||
std::optional<int> max_transpose_transfer_dst_scalar_per_vector = std::nullopt;
|
||||
std::optional<int> num_groups_to_merge = std::nullopt;
|
||||
std::optional<bool> do_pad_gemm_m = std::nullopt;
|
||||
std::optional<bool> do_pad_gemm_n = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_Tag
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = bwd_wei_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<typename InstTraits::InDataType>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
|
||||
.a_tile_transfer =
|
||||
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
|
||||
.b_tile_transfer =
|
||||
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
|
||||
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(
|
||||
InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock),
|
||||
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Wmma_CShuffle_V3_Tag
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = bwd_wei_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<typename InstTraits::InDataType>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
|
||||
.a_tile_transfer =
|
||||
conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1, InstTraits::kK0PerBlock),
|
||||
.b_tile_transfer =
|
||||
conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1, InstTraits::kK0PerBlock),
|
||||
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(
|
||||
InstTraits::kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
.do_pad_gemm_m = InstTraits::kDoPadGemmM,
|
||||
.do_pad_gemm_n = InstTraits::kDoPadGemmN,
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdData_Xdl_CShuffle_Tag
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = bwd_wei_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<typename InstTraits::InDataType>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
|
||||
.a_tile_transfer =
|
||||
conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1, InstTraits::kK0PerBlock),
|
||||
.b_tile_transfer =
|
||||
conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1, InstTraits::kK0PerBlock),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer =
|
||||
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
|
||||
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
.do_pad_gemm_m = InstTraits::kDoPadGemmM,
|
||||
.do_pad_gemm_n = InstTraits::kDoPadGemmN,
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -42,8 +42,9 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
|
||||
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kTransposeTransferDstScalarPerVector,
|
||||
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -49,8 +49,9 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
|
||||
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kTransposeTransferDstScalarPerVector,
|
||||
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,8 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
@@ -49,7 +49,8 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.max_transpose_transfer_src_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
|
||||
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
.max_transpose_transfer_dst_scalar_per_vector =
|
||||
InstTraits::kMaxTransposeTransferDstScalarPerVector,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -796,7 +796,8 @@ constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
|
||||
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer(
|
||||
ck::index_t CDEBlockTansferScalarPerVector = InstTraits::kCDEBlockTransferScalarPerVector)
|
||||
{
|
||||
return OutputTileTransferInfo{
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
|
||||
@@ -805,7 +806,7 @@ constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
|
||||
InstTraits::kCDEThreadClusterLengths[1],
|
||||
InstTraits::kCDEThreadClusterLengths[2],
|
||||
InstTraits::kCDEThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
|
||||
.scalar_per_vector = CDEBlockTansferScalarPerVector};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
|
||||
@@ -18,3 +18,8 @@
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
|
||||
// Bwd data instances
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename OutLayout, // output image
|
||||
typename WeiLayout, // weight
|
||||
typename DsLayout, // bias
|
||||
typename InLayout, // input image
|
||||
typename OutDataType, // output image
|
||||
typename WeiDataType, // weight
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType, // bias
|
||||
typename InDataType, // input image
|
||||
typename OutElementwiseOp, // output image
|
||||
typename WeiElementwiseOp, // weight
|
||||
typename InElementwiseOp, // C, bias, and input image
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
ConvBackwardDataSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::PipelineVersion PipelineVer>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle device kernel
|
||||
struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename OutLayout_, // output image
|
||||
typename WeiLayout_, // weight
|
||||
typename DsLayout_, // bias
|
||||
typename InLayout_, // input image
|
||||
typename OutDataType_, // output image
|
||||
typename WeiDataType_, // weight
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_, // bias
|
||||
typename InDataType_, // input image
|
||||
typename OutElementwiseOp_, // output image
|
||||
typename WeiElementwiseOp_, // weight
|
||||
typename InElementwiseOp_, // C, bias, and input image
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
ConvBackwardDataSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1_,
|
||||
typename ABlockTransferThreadClusterArrangeOrder_,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1_,
|
||||
typename BBlockTransferThreadClusterArrangeOrder_,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::PipelineVersion PipelineVer>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
|
||||
NDimSpatial,
|
||||
OutLayout_, // output image
|
||||
WeiLayout_, // weight
|
||||
DsLayout_, // bias
|
||||
InLayout_, // input image
|
||||
OutDataType_, // output image
|
||||
WeiDataType_, // weight
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_, // bias
|
||||
InDataType_, // input image
|
||||
OutElementwiseOp_, // output image
|
||||
WeiElementwiseOp_, // weight
|
||||
InElementwiseOp_, // C, bias, and input image
|
||||
ConvBackwardDataSpecialization,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_,
|
||||
ABlockTransferThreadClusterArrangeOrder_,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_,
|
||||
BBlockTransferThreadClusterArrangeOrder_,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumGemmKPrefetchStage,
|
||||
LoopSched,
|
||||
PipelineVer>>
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
|
||||
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
|
||||
using InDataType = InDataType_;
|
||||
using WeiDataType = WeiDataType_;
|
||||
using OutDataType = OutDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
|
||||
using InElementwiseOperation = InElementwiseOp_;
|
||||
using WeiElementwiseOperation = WeiElementwiseOp_;
|
||||
using OutElementwiseOperation = OutElementwiseOp_;
|
||||
|
||||
static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
|
||||
static constexpr ck::index_t kK1 = K1;
|
||||
static constexpr ck::index_t kMPerWmma = MPerWMMA;
|
||||
static constexpr ck::index_t kNPerWmma = NPerWMMA;
|
||||
static constexpr ck::index_t kMRepeat = MRepeat;
|
||||
static constexpr ck::index_t kNRepeat = NRepeat;
|
||||
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
|
||||
static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
|
||||
static constexpr ck::index_t kCDEShuffleBlockTransferScalarPerVector_NPerBlock =
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
using ABlockTransferThreadClusterLengths_AK0_M_AK1 =
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << kTensorOpName;
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<InLayout>(); // 5. InLayout
|
||||
oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
|
||||
oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
|
||||
oss << "," << detail::type_name<InDataType>(); // 10. InDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
|
||||
// WeiElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_data_spec_name(
|
||||
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 18. K0PerBlock
|
||||
oss << "," << kK1; // 19. ABK1
|
||||
oss << "," << kMPerWmma; // 20. MPerWmma
|
||||
oss << "," << kNPerWmma; // 21. NPerWmma
|
||||
oss << "," << kMRepeat; // 22. MRepeat
|
||||
oss << "," << kNRepeat; // 23. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 27.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 34.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
|
||||
oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock; // 41.
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 42. LoopSched
|
||||
oss << "," << kNumGemmKPrefetchStage; // 43.
|
||||
oss << "," << detail::pipeline_version_name(kPipelineVer); // 44.
|
||||
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,350 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename OutLayout, // output image
|
||||
typename WeiLayout, // weight
|
||||
typename DsLayout, // bias
|
||||
typename InLayout, // input image
|
||||
typename OutDataType, // output image
|
||||
typename WeiDataType, // weight
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType, // bias
|
||||
typename InDataType, // input image
|
||||
typename OutElementwiseOp, // output image
|
||||
typename WeiElementwiseOp, // weight
|
||||
typename InElementwiseOp, // C, bias, and input image
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
ConvBackwardDataSpecialization,
|
||||
bool DoPadGemmM,
|
||||
bool DoPadGemmN,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB,
|
||||
ck::index_t max_transpose_transfer_src_scalar_per_vector,
|
||||
ck::index_t max_transpose_transfer_dst_scalar_per_vector>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 device kernel
|
||||
struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename OutLayout_, // output image
|
||||
typename WeiLayout_, // weight
|
||||
typename DsLayout_, // bias
|
||||
typename InLayout_, // input image
|
||||
typename OutDataType_, // output image
|
||||
typename WeiDataType_, // weight
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_, // bias
|
||||
typename InDataType_, // input image
|
||||
typename OutElementwiseOp_, // output image
|
||||
typename WeiElementwiseOp_, // weight
|
||||
typename InElementwiseOp_, // C, bias, and input image
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
ConvBackwardDataSpecialization,
|
||||
bool DoPadGemmM,
|
||||
bool DoPadGemmN,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1_,
|
||||
typename ABlockTransferThreadClusterArrangeOrder_,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1_,
|
||||
typename BBlockTransferThreadClusterArrangeOrder_,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
typename CDEShuffleBlockTransferScalarPerVector_NPerBlock_,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA_,
|
||||
typename ComputeTypeB_,
|
||||
ck::index_t max_transpose_transfer_src_scalar_per_vector,
|
||||
ck::index_t max_transpose_transfer_dst_scalar_per_vector>
|
||||
struct InstanceTraits<
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
OutLayout_, // output image
|
||||
WeiLayout_, // weight
|
||||
DsLayout_, // bias
|
||||
InLayout_, // input image
|
||||
OutDataType_, // output image
|
||||
WeiDataType_, // weight
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_, // bias
|
||||
InDataType_, // input image
|
||||
OutElementwiseOp_, // output image
|
||||
WeiElementwiseOp_, // weight
|
||||
InElementwiseOp_, // C, bias, and input image
|
||||
ConvBackwardDataSpecialization,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_,
|
||||
ABlockTransferThreadClusterArrangeOrder_,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_,
|
||||
BBlockTransferThreadClusterArrangeOrder_,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock_,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA_,
|
||||
ComputeTypeB_,
|
||||
max_transpose_transfer_src_scalar_per_vector,
|
||||
max_transpose_transfer_dst_scalar_per_vector>>
|
||||
{
|
||||
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3";
|
||||
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_V3_Tag;
|
||||
|
||||
static constexpr ck::index_t kSpatialDim = NDimSpatial;
|
||||
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
|
||||
using InDataType = InDataType_;
|
||||
using WeiDataType = WeiDataType_;
|
||||
using OutDataType = OutDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
|
||||
using InElementwiseOperation = InElementwiseOp_;
|
||||
using WeiElementwiseOperation = WeiElementwiseOp_;
|
||||
using OutElementwiseOperation = OutElementwiseOp_;
|
||||
|
||||
static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;
|
||||
|
||||
static constexpr ck::index_t kBlockSize = BlockSize;
|
||||
static constexpr ck::index_t kMPerBlock = MPerBlock;
|
||||
static constexpr ck::index_t kNPerBlock = NPerBlock;
|
||||
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
|
||||
static constexpr ck::index_t kAK1 = AK1;
|
||||
static constexpr ck::index_t kBK1 = BK1;
|
||||
static constexpr ck::index_t kMPerWmma = MPerWMMA;
|
||||
static constexpr ck::index_t kNPerWmma = NPerWMMA;
|
||||
static constexpr ck::index_t kMRepeat = MRepeat;
|
||||
static constexpr ck::index_t kNRepeat = NRepeat;
|
||||
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
|
||||
static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
|
||||
static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
|
||||
max_transpose_transfer_src_scalar_per_vector;
|
||||
static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
|
||||
max_transpose_transfer_dst_scalar_per_vector;
|
||||
static constexpr bool kDoPadGemmM = DoPadGemmM;
|
||||
static constexpr bool kDoPadGemmN = DoPadGemmN;
|
||||
using CDEShuffleBlockTransferScalarPerVector_NPerBlock =
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEShuffleBlockTransferScalarPerVector_NPerBlock =
|
||||
detail::SequenceToArray<CDEShuffleBlockTransferScalarPerVector_NPerBlock>::value;
|
||||
|
||||
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
|
||||
|
||||
using ABlockTransferThreadClusterLengths_AK0_M_AK1 =
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
|
||||
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
|
||||
ABlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
|
||||
ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
|
||||
BBlockTransferSrcScalarPerVector;
|
||||
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
|
||||
BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
|
||||
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
|
||||
using ComputeTypeA = ComputeTypeA_;
|
||||
using ComputeTypeB = ComputeTypeB_;
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << kTensorOpName;
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<InLayout>(); // 5. InLayout
|
||||
oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
|
||||
oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
|
||||
oss << "," << detail::type_name<InDataType>(); // 10. InDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
|
||||
// OutElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
|
||||
// WeiElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_bwd_data_spec_name(
|
||||
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
|
||||
oss << "," << kDoPadGemmM;
|
||||
oss << "," << kDoPadGemmN;
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 18. K0PerBlock
|
||||
oss << "," << kAK1; // 19. ABK1
|
||||
oss << "," << kBK1; // 19. ABK1
|
||||
oss << "," << kMPerWmma; // 20. MPerWmma
|
||||
oss << "," << kNPerWmma; // 21. NPerWmma
|
||||
oss << "," << kMRepeat; // 22. MRepeat
|
||||
oss << "," << kNRepeat; // 23. NRepeat
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 27.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 34.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
|
||||
oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock[0]; // 41.
|
||||
oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 43.
|
||||
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 44.
|
||||
oss << "," << detail::type_name<ComputeTypeA>(); // 45.
|
||||
oss << "," << detail::type_name<ComputeTypeB>(); // 46.
|
||||
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 47.
|
||||
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 48.
|
||||
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,345 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
|
||||
namespace ck::tensor_operation::device {

// Forward declaration of the grouped convolution backward-data device kernel
// (XDL, CShuffle, multiple-D variant). Only the template signature is needed
// here so the reflection traits below can be partially specialized on it; the
// full definition lives in the CK device-operation headers.
//
// NOTE(review): ABlockLdsAddExtraM / BBlockLdsAddExtraN are declared as
// ck::index_t here, while the InstanceTraits partial specialization for this
// kernel declares them as bool — confirm against the kernel's actual
// definition. A non-type parameter kind mismatch would both be an invalid
// redeclaration and prevent the partial specialization from ever matching.
template <ck::index_t NDimSpatial,
          typename OutLayout,
          typename WeiLayout,
          typename DsLayout,
          typename InLayout,
          typename OutDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutComputeType,
          typename DsDataType,
          typename InDataType,
          typename OutElementwiseOperation,
          typename WeiElementwiseOperation,
          typename InElementwiseOperation,
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          bool do_pad_gemm_m,
          bool do_pad_gemm_n,
          ck::index_t num_gemm_k_prefetch_stages,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          ck::index_t ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          ck::index_t BBlockLdsAddExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
          ck::LoopScheduler LoopSched,
          typename ComputeTypeA,
          typename ComputeTypeB,
          ck::index_t max_transpose_transfer_src_scalar_per_vector,
          ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;

} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag;

/// @brief Compile-time reflection traits for the XDL CShuffle backward-data
///        grouped convolution kernel (multiple-D variant).
///
/// Re-exports the kernel's template parameters as named traits and builds a
/// canonical instance string (via instance_string()) mirroring the kernel's
/// template-parameter order.
template <ck::index_t NDimSpatial,
          typename OutLayout_,
          typename WeiLayout_,
          typename DsLayout_,
          typename InLayout_,
          typename OutDataType_,
          typename WeiDataType_,
          typename AccDataType_,
          typename OutComputeType_,
          typename DsDataType_,
          typename InDataType_,
          typename OutElementwiseOperation_,
          typename WeiElementwiseOperation_,
          typename InElementwiseOperation_,
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          bool do_pad_gemm_m,
          bool do_pad_gemm_n,
          ck::index_t num_gemm_k_prefetch_stages,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1_,
          typename ABlockTransferThreadClusterArrangeOrder_,
          typename ABlockTransferSrcAccessOrder_,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1_,
          typename BBlockTransferThreadClusterArrangeOrder_,
          typename BBlockTransferSrcAccessOrder_,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsAddExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
          ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
          ck::LoopScheduler LoopSched,
          typename ComputeTypeA_,
          typename ComputeTypeB_,
          ck::index_t max_transpose_transfer_src_scalar_per_vector,
          ck::index_t max_transpose_transfer_dst_scalar_per_vector>
struct InstanceTraits<
    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
        NDimSpatial,
        OutLayout_,
        WeiLayout_,
        DsLayout_,
        InLayout_,
        OutDataType_,
        WeiDataType_,
        AccDataType_,
        OutComputeType_,
        DsDataType_,
        InDataType_,
        OutElementwiseOperation_,
        WeiElementwiseOperation_,
        InElementwiseOperation_,
        ConvBackwardDataSpecialization,
        do_pad_gemm_m,
        do_pad_gemm_n,
        num_gemm_k_prefetch_stages,
        BlockSize,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1_,
        ABlockTransferThreadClusterArrangeOrder_,
        ABlockTransferSrcAccessOrder_,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1_,
        BBlockTransferThreadClusterArrangeOrder_,
        BBlockTransferSrcAccessOrder_,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        BBlockLdsAddExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
        CBlockTransferScalarPerVector_NWaveNPerXdl,
        LoopSched,
        ComputeTypeA_,
        ComputeTypeB_,
        max_transpose_transfer_src_scalar_per_vector,
        max_transpose_transfer_dst_scalar_per_vector>>
{
    // Human-readable kernel family name; used as the prefix of instance_string().
    static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle";

    /// @brief Tag type identifying this device kernel variant
    using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Xdl_CShuffle_Tag;

    // Number of spatial dimensions of the convolution (e.g. 2 or 3).
    static constexpr ck::index_t kSpatialDim = NDimSpatial;

    // Tensor layouts.
    using InLayout  = InLayout_;
    using WeiLayout = WeiLayout_;
    using OutLayout = OutLayout_;
    using DsLayout  = DsLayout_;

    // Tensor element types.
    // NOTE(review): the kernel's OutComputeType_ parameter is consumed by this
    // specialization but not re-exported as a trait — confirm that is intentional.
    using InDataType  = InDataType_;
    using WeiDataType = WeiDataType_;
    using OutDataType = OutDataType_;
    using AccDataType = AccDataType_;
    using DsDataType  = DsDataType_;

    // Elementwise operations applied to the respective tensors.
    using InElementwiseOperation  = InElementwiseOperation_;
    using WeiElementwiseOperation = WeiElementwiseOperation_;
    using OutElementwiseOperation = OutElementwiseOperation_;

    // Backward-data specialization (e.g. Default, Filter1x1Stride1Pad0).
    static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;

    // Thread-block / GEMM tile configuration.
    static constexpr ck::index_t kBlockSize   = BlockSize;
    static constexpr ck::index_t kMPerBlock   = MPerBlock;
    static constexpr ck::index_t kNPerBlock   = NPerBlock;
    static constexpr ck::index_t kK0PerBlock  = K0PerBlock;
    static constexpr ck::index_t kAK1         = AK1;
    static constexpr ck::index_t kBK1         = BK1;
    static constexpr ck::index_t kMPerXDL     = MPerXDL;
    static constexpr ck::index_t kNPerXDL     = NPerXDL;
    static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
    static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
    static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
    static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
    static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
        CBlockTransferScalarPerVector_NWaveNPerXdl;
    static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
        max_transpose_transfer_src_scalar_per_vector;
    static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
        max_transpose_transfer_dst_scalar_per_vector;
    static constexpr bool kDoPadGemmM = do_pad_gemm_m;
    static constexpr bool kDoPadGemmN = do_pad_gemm_n;
    // NOTE(review): declared as int while sibling traits use ck::index_t —
    // confirm whether this should be ck::index_t for consistency.
    static constexpr int kNumGemmKPrefetchStage = num_gemm_k_prefetch_stages;

    // A-tensor block transfer configuration.
    using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
    using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
    using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
    // A block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kAThreadClusterLengths =
        detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
    static constexpr auto kAThreadClusterArrangeOrder =
        detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kABlockTransferSrcAccessOrder =
        detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
    static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
        ABlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
        ABlockTransferDstScalarPerVector_AK1;
    static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;

    // B-tensor block transfer configuration.
    using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
    using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
    using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
    // B block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kBThreadClusterLengths =
        detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
    static constexpr auto kBThreadClusterArrangeOrder =
        detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kBBlockTransferSrcAccessOrder =
        detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
    static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
        BBlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
        BBlockTransferDstScalarPerVector_BK1;
    static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;

    // C-shuffle block transfer cluster configuration.
    using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;

    // Compute types used for the A/B operands of the GEMM.
    using ComputeTypeA = ComputeTypeA_;
    using ComputeTypeB = ComputeTypeB_;

    static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;

    // C block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;

    // Static member function to generate instance string
    /// @brief Builds the canonical instance string that mirrors the kernel's
    ///        template-parameter order; used to match generated instances.
    static std::string instance_string()
    {
        std::ostringstream oss;

        // Kernel type name
        oss << kTensorOpName;

        // Template parameters in exact order
        oss << "<" << kSpatialDim;                      // 1. NDimSpatial
        oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
        oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
        oss << "," << detail::tuple_name<DsLayout>();   // 4. DsLayout
        oss << "," << detail::layout_name<InLayout>();  // 5. InLayout
        oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
        oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
        oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
        oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
        oss << "," << detail::type_name<InDataType>();  // 10. InDataType
        oss << ","
            << detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
                                                                       // OutElementwiseOperation
        oss << ","
            << detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
                                                                       // WeiElementwiseOperation
        oss << ","
            << detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
        oss << ","
            << detail::conv_bwd_data_spec_name(
                   kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
        oss << "," << kDoPadGemmM;
        oss << "," << kDoPadGemmN;
        oss << "," << kNumGemmKPrefetchStage;
        oss << "," << kBlockSize;  // 15. BlockSize
        oss << "," << kMPerBlock;  // 16. MPerBlock
        oss << "," << kNPerBlock;  // 17. NPerBlock
        oss << "," << kK0PerBlock; // 18. K0PerBlock
        oss << "," << kAK1;        // 19. AK1
        oss << "," << kBK1;        // 19b. BK1
        oss << "," << kMPerXDL;    // 20. MPerXDL
        oss << "," << kNPerXDL;    // 21. NPerXDL
        oss << "," << kMXdlPerWave; // 22. MXdlPerWave
        oss << "," << kNXdlPerWave; // 23. NXdlPerWave
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>();    // 25.
        oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>();               // 26.
        oss << "," << kABlockTransferSrcVectorDim;           // 27.
        oss << "," << kABlockTransferSrcScalarPerVector;     // 28.
        oss << "," << kABlockTransferDstScalarPerVectorK1;   // 29.
        oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>();    // 32.
        oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>();               // 33.
        oss << "," << kBBlockTransferSrcVectorDim;           // 34.
        oss << "," << kBBlockTransferSrcScalarPerVector;     // 35.
        oss << "," << kBBlockTransferDstScalarPerVectorK1;   // 36.
        oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
        oss << "," << kCShuffleMXdlPerWavePerShuffle;        // 38.
        oss << "," << kCShuffleNXdlPerWavePerShuffle;        // 39.
        oss << ","
            << detail::sequence_name<
                   CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
        oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42.
        // NOTE(review): kNumGemmKPrefetchStage was already streamed above (right
        // after kDoPadGemmN), and the out-of-order "// 41." after "// 42."
        // suggests a copy-paste slip — confirm this second occurrence is
        // intentional and matches the generated instance names.
        oss << "," << kNumGemmKPrefetchStage;                      // 41.
        oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
        oss << "," << detail::type_name<ComputeTypeA>();           // 44.
        oss << "," << detail::type_name<ComputeTypeB>();           // 45.
        oss << "," << kMaxTransposeTransferSrcScalarPerVector;     // 46.
        oss << "," << kMaxTransposeTransferDstScalarPerVector;     // 47.

        oss << ">";

        return oss.str();
    }
};

} // namespace ck_tile::reflect
|
||||
@@ -18,6 +18,7 @@
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -138,6 +139,18 @@ constexpr std::string_view conv_bwd_weight_spec_name(
|
||||
}
|
||||
}
|
||||
|
||||
// Convert ConvolutionBackwardDataSpecialization enum to string
|
||||
constexpr std::string_view
|
||||
conv_bwd_data_spec_name(ck::tensor_operation::device::ConvolutionBackwardDataSpecialization spec)
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return "Default";
|
||||
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
|
||||
@@ -197,6 +197,9 @@ target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
|
||||
|
||||
# Unit tests for the CK builder's backward-data convolution instances:
# the ck_tile v3 kernel plus the ck XDL / WMMA / WMMA-v3 multiple-D kernels.
add_ck_builder_test(test_ckb_build_bwd_data_instances
    conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
    conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp
    conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp
    conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle_v3.cpp
)
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;

// Problem signature: 2D FP16 backward-data grouped convolution with FP32
// accumulation, GNHWC input / GKYXC weight / GNHWK output layouts.
constexpr auto SIGNATURE =
    ckt::ConvSignature{.spatial_dim = 2,
                       .direction = ckb::ConvDirection::BACKWARD_DATA,
                       .data_type = ckb::DataType::FP16,
                       .accumulation_data_type = ckb::DataType::FP32,
                       .input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
                       .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
                       .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};

// Algorithm selection: the WMMA CShuffle multiple-D backward-data kernel with
// a 64-thread 32x32x32 block tile, 16x16 WMMA tiles (2x1 per wave), default
// backward-data specialization, single-stage prefetch and the v1 gridwise
// GEMM pipeline.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle{}
                               .with_thread_block(cku::ThreadBlock_64_32x32x32)
                               .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
                               .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
                               .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
                               .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
                               .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);

using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// NOTE(review): unused in the test body — presumably kept so that
// Builder::Instance is forced to resolve at compile time; confirm.
using Instance = Builder::Instance;

// Builds the instance and checks it against the expected kernel name,
// transfer parameters, specialization, layouts, elementwise operations and
// compute types (run_test presumably matches these substrings against the
// instance string — see ckb_conv_test_utils).
TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_GNHWC, Create)
{
    const auto expected_transfer_parameters = to_string(ALGORITHM);
    std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
    cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle",
                            expected_transfer_parameters,
                            "Default",
                            "GNHWK,GKYXC,EmptyTuple,GNHWC",
                            "PassThrough,PassThrough,PassThrough",
                            "fp16,fp16"}); // check compute types
}
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;

// Problem signature: 2D FP16 backward-data grouped convolution with FP32
// accumulation, GNHWC input / GKYXC weight / GNHWK output layouts.
constexpr auto SIGNATURE =
    ckt::ConvSignature{.spatial_dim = 2,
                       .direction = ckb::ConvDirection::BACKWARD_DATA,
                       .data_type = ckb::DataType::FP16,
                       .accumulation_data_type = ckb::DataType::FP32,
                       .input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
                       .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
                       .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};

// Algorithm selection: the WMMA CShuffle V3 multiple-D backward-data kernel.
// Unlike the non-V3 test, this variant uses separate A/B K1 GEMM parameters
// (GemmParamsABK1), explicit gemm-padding flags, a v1 intrawave block-GEMM
// pipeline descriptor and transpose-transfer vector widths of 2/2.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3{}
                               .with_thread_block(cku::ThreadBlock_64_32x32x32)
                               .with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_2x1_per_wave)
                               .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
                               .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
                               .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
                               .with_gemm_pad_params(0, 0)
                               .with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
                               .with_transpose_params(2, 2);

using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// NOTE(review): unused in the test body — presumably kept so that
// Builder::Instance is forced to resolve at compile time; confirm.
using Instance = Builder::Instance;

// Builds the instance and checks it against the expected kernel name,
// transfer parameters, specialization, layouts, elementwise operations and
// compute types (run_test presumably matches these substrings against the
// instance string — see ckb_conv_test_utils).
TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_V3_GNHWC, Create)
{
    const auto expected_transfer_parameters = to_string(ALGORITHM);
    std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
    cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3",
                            expected_transfer_parameters,
                            "Default",
                            "GNHWK,GKYXC,EmptyTuple,GNHWC",
                            "PassThrough,PassThrough,PassThrough",
                            "fp16,fp16"}); // check compute types
}
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;

// Problem signature: 2D FP16 backward-data grouped convolution with FP32
// accumulation, GNHWC input / GKYXC weight / GNHWK output layouts.
constexpr auto SIGNATURE =
    ckt::ConvSignature{.spatial_dim = 2,
                       .direction = ckb::ConvDirection::BACKWARD_DATA,
                       .data_type = ckb::DataType::FP16,
                       .accumulation_data_type = ckb::DataType::FP32,
                       .input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
                       .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
                       .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};

// Algorithm selection: the XDL CShuffle multiple-D backward-data kernel with
// a 256-thread 256x128x32 block tile, 4x4 XDL tiles per wave, default
// backward-data specialization, single-stage prefetch, no explicit gemm
// padding and transpose-transfer vector widths of 2/2.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle{}
                               .with_thread_block(cku::ThreadBlock_256_256x128x32)
                               .with_gemm_config(cku::BwdDataGemmParams_Xdl_4x4_per_wave)
                               .with_transfer(cku::Transfer_4x64x1)
                               .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
                               .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
                               .with_gemm_pad_params(0, 0)
                               .with_transpose_params(2, 2);

using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// NOTE(review): unused in the test body — presumably kept so that
// Builder::Instance is forced to resolve at compile time; confirm.
using Instance = Builder::Instance;

// Builds the instance and checks it against the expected kernel name,
// transfer parameters, specialization, layouts, elementwise operations and
// compute types (run_test presumably matches these substrings against the
// instance string — see ckb_conv_test_utils).
TEST(BwdData_2DFp16_MultiD_Xdl_CShuffle_GNHWC, Create)
{
    const auto expected_transfer_parameters = to_string(ALGORITHM);
    std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
    cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle",
                            expected_transfer_parameters,
                            "Default",
                            "GNHWK,GKYXC,EmptyTuple,GNHWC",
                            "PassThrough,PassThrough,PassThrough",
                            "fp16,fp16"}); // check compute types
}
|
||||
@@ -19,6 +19,9 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -35,7 +38,390 @@ class ConvTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
// instantiate the device template with known parameters and verify that
// instance_to_conv_traits reflects each of them back correctly.
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters.
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWK, // OutLayout
            ck::tensor_layout::convolution::GKYXC, // WeiLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWC, // InLayout
            ck::half_t,                            // OutDataType
            ck::half_t,                            // WeiDataType
            ck::half_t, // InDataType (was labelled OutDataType — copy/paste slip; TODO confirm)
            float,      // AccDataType
            ck::Tuple<>, // DsDataType
            float,       // OutComputeType (presumably In/CShuffle compute type — TODO confirm)
            ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Default, // ConvBackwardDataSpecialization
            256,         // BlockSize
            128,         // MPerBlock
            128,         // NPerBlock
            16,          // K0PerBlock
            8,           // K1
            32,          // MPerWMMA
            32,          // NPerWMMA
            4,           // MRepeat
            4,           // NRepeat
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_K1
            1,                      // ABlockLdsAddExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder_
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_K1
            1,                      // BBlockLdsAddExtraN
            1,                      // CShuffleMRepeatPerWavePerShuffle
            1,                      // CShuffleNRepeatPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
            8,               // CDEBlockTransferScalarPerVector_NPerBlock_
            2,               // NumGemmKPrefetchStage
            ck::LoopScheduler::Default, // BlkGemmPipeSched
            ck::PipelineVersion::v1>;   // PipelineVersion

    // Use ConvTraitsTmpl to extract compile-time information.
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Verify signature information.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
    // NOTE(review): instance tensor data types are half_t (FP16) but FP32 is
    // expected here — presumably traits.data_type reflects the float
    // compute/accumulation type; confirm against instance_to_conv_traits.
    EXPECT_EQ(traits.data_type, DataType::FP32);
    EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);

    // Verify specializations.
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);

    // Verify algorithm information.
    EXPECT_EQ(traits.thread_block_size, 256);

    // Verify tile dimensions.
    EXPECT_EQ(traits.tile_dims.m, 128);
    EXPECT_EQ(traits.tile_dims.n, 128);
    EXPECT_EQ(traits.tile_dims.k, 16);

    // Verify A tile transfer info (k0 = K0PerBlock / K1 = 16 / 8 = 2).
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);

    // Verify B tile transfer info.
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);

    // Verify warp GEMM params.
    EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
    EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
    EXPECT_EQ(traits.warp_gemm.m_iter, 4);
    EXPECT_EQ(traits.warp_gemm.n_iter, 4);

    // Verify output tile transfer info.
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
    EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
    EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 2);

    // Verify pipeline configuration.
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3:
// V3 adds separate A/B K1 values, explicit GEMM padding flags, a block-GEMM
// pipeline scheduler/version, and transpose-transfer limits.
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleWmmaV3TraitsExtraction)
{
    // Define a concrete instance type with specific template parameters.
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWK, // OutLayout
            ck::tensor_layout::convolution::GKYXC, // WeiLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWC, // InLayout
            ck::half_t,                            // OutDataType
            ck::half_t,                            // WeiDataType
            ck::half_t, // InDataType (was labelled OutDataType — copy/paste slip; TODO confirm)
            float,      // AccDataType
            ck::Tuple<>, // DsDataType
            float,       // OutComputeType (presumably In/CShuffle compute type — TODO confirm)
            ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Default, // ConvBackwardDataSpecialization
            false,       // DoPadGemmM
            false,       // DoPadGemmN
            256,         // BlockSize
            128,         // MPerBlock
            128,         // NPerBlock
            16,          // K0PerBlock
            8,           // AK1
            8,           // BK1
            32,          // MPerWMMA
            32,          // NPerWMMA
            4,           // MRepeat
            4,           // NRepeat
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_K1
            1,                      // ABlockLdsAddExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder_
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_K1
            1,                      // BBlockLdsAddExtraN
            1,                      // CShuffleMRepeatPerWavePerShuffle
            1,                      // CShuffleNRepeatPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
            ck::Sequence<8, 8, 8>, // CDEBlockTransferScalarPerVector_NPerBlock_
            ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
            ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
            ck::half_t,                                // AComputeDataType
            ck::half_t,                                // BComputeDataType
            1, // MaxTransposeTransferSrcScalarPerVector
            1>; // MaxTransposeTransferDstScalarPerVector

    // Use ConvTraitsTmpl to extract compile-time information.
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Verify signature information.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
    // NOTE(review): instance tensor data types are half_t (FP16) but FP32 is
    // expected — presumably the float compute/accumulation type; confirm.
    EXPECT_EQ(traits.data_type, DataType::FP32);
    EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);

    // Verify specializations.
    // NOTE(review): the instance sets BlockGemmPipelineScheduler::Intrawave yet
    // DEFAULT is expected here — verify how the traits map the block-GEMM
    // scheduler onto PipelineScheduler.
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);

    // Verify algorithm information.
    EXPECT_EQ(traits.thread_block_size, 256);

    // Verify tile dimensions.
    EXPECT_EQ(traits.tile_dims.m, 128);
    EXPECT_EQ(traits.tile_dims.n, 128);
    EXPECT_EQ(traits.tile_dims.k, 16);

    // Verify A tile transfer info (k0 = K0PerBlock / AK1 = 16 / 8 = 2).
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
    EXPECT_FALSE(traits.do_pad_gemm_n.value());
    EXPECT_FALSE(traits.do_pad_gemm_m.value());

    // Verify B tile transfer info.
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);

    // Verify warp GEMM params.
    EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
    EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
    EXPECT_EQ(traits.warp_gemm.m_iter, 4);
    EXPECT_EQ(traits.warp_gemm.n_iter, 4);

    // Verify output tile transfer info.
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
    EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);

    EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
    EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);

    // Verify pipeline configuration.
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle (v1):
// the XDL backward-data device with explicit pad flags, a prefetch-stage
// count, separate A/B K1 values, and transpose-transfer limits.
TEST_F(ConvTraitsTest, ConvBwdDataMultipleDCshuffleXDLTraitsExtraction)
{
    // Define a concrete instance type with specific template parameters.
    using DeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
            2,                                     // NDimSpatial
            ck::tensor_layout::convolution::GNHWK, // OutLayout
            ck::tensor_layout::convolution::GKYXC, // WeiLayout
            ck::Tuple<>,                           // DsLayout
            ck::tensor_layout::convolution::GNHWC, // InLayout
            ck::half_t,                            // OutDataType
            ck::half_t,                            // WeiDataType
            ck::half_t, // InDataType (was labelled OutDataType — copy/paste slip; TODO confirm)
            float,      // AccDataType
            ck::Tuple<>, // DsDataType
            float,       // OutComputeType (presumably In/CShuffle compute type — TODO confirm)
            ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
            ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
            ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                Default, // ConvBackwardDataSpecialization
            false,       // DoPadGemmM
            false,       // DoPadGemmN
            1,           // num_gemm_k_prefetch_stage
            256,         // BlockSize
            128,         // MPerBlock
            128,         // NPerBlock
            16,          // K0PerBlock
            8,           // AK1
            8,           // BK1
            32,          // MPerXDL
            32,          // NPerXDL
            4,           // MXdlPerWave
            4,           // NXdlPerWave
            ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
            ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder
            2,                      // ABlockTransferSrcVectorDim
            8,                      // ABlockTransferSrcScalarPerVector
            8,                      // ABlockTransferDstScalarPerVector_K1
            1,                      // ABlockLdsAddExtraM
            ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
            ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder_
            ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder_
            2,                      // BBlockTransferSrcVectorDim
            8,                      // BBlockTransferSrcScalarPerVector
            8,                      // BBlockTransferDstScalarPerVector_K1
            1,                      // BBlockLdsAddExtraN
            1,                      // CShuffleMXdlPerWavePerShuffle
            1,                      // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1,
                         32,
                         1,
                         8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
            8,               // CDEBlockTransferScalarPerVector_NPerBlock_
            ck::LoopScheduler::Default, // BlkGemmPipeSched
            ck::half_t,                 // AComputeDataType
            ck::half_t,                 // BComputeDataType
            1, // MaxTransposeTransferSrcScalarPerVector
            1>; // MaxTransposeTransferDstScalarPerVector

    // Use ConvTraitsTmpl to extract compile-time information.
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Verify signature information.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_DATA);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
    // NOTE(review): instance tensor data types are half_t (FP16) but FP32 is
    // expected — presumably the float compute/accumulation type; confirm.
    EXPECT_EQ(traits.data_type, DataType::FP32);
    EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);

    // Verify specializations.
    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1);

    // Verify algorithm information.
    EXPECT_EQ(traits.thread_block_size, 256);

    // Verify tile dimensions.
    EXPECT_EQ(traits.tile_dims.m, 128);
    EXPECT_EQ(traits.tile_dims.n, 128);
    EXPECT_EQ(traits.tile_dims.k, 16);

    // Verify A tile transfer info (k0 = K0PerBlock / AK1 = 16 / 8 = 2).
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
    EXPECT_FALSE(traits.do_pad_gemm_n.value());
    EXPECT_FALSE(traits.do_pad_gemm_m.value());

    // Verify B tile transfer info.
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
    EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
    EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
    EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
    EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);

    // Verify warp GEMM params.
    EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
    EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
    EXPECT_EQ(traits.warp_gemm.m_iter, 4);
    EXPECT_EQ(traits.warp_gemm.n_iter, 4);

    // Verify output tile transfer info.
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
    EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
    EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
    EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);

    EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
    EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
@@ -270,6 +656,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction)
|
||||
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
|
||||
// Verify pipeline configuration
|
||||
}
|
||||
|
||||
@@ -516,6 +905,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3
|
||||
@@ -640,6 +1032,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
@@ -1001,6 +1396,9 @@ TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction)
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
|
||||
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
|
||||
|
||||
EXPECT_EQ(traits.max_transpose_transfer_src_scalar_per_vector, 1);
|
||||
EXPECT_EQ(traits.max_transpose_transfer_dst_scalar_per_vector, 1);
|
||||
}
|
||||
|
||||
// test conv traits device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
||||
|
||||
@@ -54,6 +54,13 @@ struct GridwiseBwdXdlGemm
|
||||
};
|
||||
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdDataXdlGemm
|
||||
{
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
{
|
||||
@@ -64,6 +71,16 @@ struct GridwiseWmmaGemm
|
||||
size_t n_wmma_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
// Describe gridwise WMMA GEMM parameters for kernels that take separate
// innermost-K (K1) sizes for the A and B operands (the *_V3 devices), as
// opposed to GridwiseWmmaGemm's single shared k1.
struct GridwiseWmmaGemmABK1
{
    size_t ak1 = 0;             // innermost K subdimension size for operand A
    size_t bk1 = 0;             // innermost K subdimension size for operand B
    size_t m_per_wmma = 0;      // WMMA subtile size in M
    size_t n_per_wmma = 0;      // WMMA subtile size in N
    size_t m_wmma_per_wave = 0; // WMMA iterations per wave in M
    size_t n_wmma_per_wave = 0; // WMMA iterations per wave in N
};
// Satisfies the ak1/bk1 branch of the GridwiseWmmaGemmDescriptor concept.
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemmABK1>);
|
||||
|
||||
struct BlockGemmPipeline
|
||||
{
|
||||
@@ -209,11 +226,21 @@ struct BwdXdlGemm_
|
||||
GridwiseBwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
// Mixin carrying the gridwise GEMM description for backward-data XDL kernels
// (separate A/B K1 sizes).
struct BwdDataXdlGemm_
{
    GridwiseBwdDataXdlGemm gridwise_gemm;
};

// Mixin carrying the gridwise GEMM description for WMMA kernels with a single
// shared K1.
struct WmmaGemm_
{
    GridwiseWmmaGemm gridwise_gemm;
};

// Mixin carrying the gridwise GEMM description for WMMA kernels with separate
// A/B K1 sizes (the *_V3 devices).
struct WmmaGemmABK1_
{
    GridwiseWmmaGemmABK1 gridwise_gemm;
};
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer_
|
||||
{
|
||||
@@ -231,12 +258,23 @@ struct ConvSpecializationBwdWeight_
|
||||
ConvSpecialization bwd_weight_specialization;
|
||||
};
|
||||
|
||||
// Mixin selecting the convolution specialization for backward-data kernels
// (set via with_bwd_data_specialization).
struct ConvSpecializationBwdData_
{
    ConvSpecialization bwd_data_specialization;
};

// Mixin holding the GEMM-K prefetch depth and the loop scheduling policy
// (set via with_prefetch_config).
struct Prefetch_
{
    size_t num_gemm_k_prefetch_stages;
    PipelineScheduler loop_scheduler;
};

// Mixin holding the GEMM padding flags (set via with_gemm_pad_params).
// NOTE(review): these are logically booleans — the device templates take
// bool DoPadGemmM/DoPadGemmN — but are stored as size_t, and the PascalCase
// member names break this file's snake_case convention; consider tightening.
struct GemmPad_
{
    size_t DoPadGemmM;
    size_t DoPadGemmN;
};
|
||||
|
||||
struct TransposeParams_
|
||||
{
|
||||
size_t max_transpose_transfer_src_scalar_per_vector{1};
|
||||
@@ -394,10 +432,18 @@ struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<BwdDataXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<WmmaGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<WmmaGemmABK1_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unrecognized GemmConfig type");
|
||||
@@ -433,6 +479,14 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
    // Returns a copy of this algorithm with the backward-data convolution
    // specialization set. Only valid when the algorithm composition includes
    // the ConvSpecializationBwdData_ mixin (enforced at compile time).
    constexpr auto with_bwd_data_specialization(ConvSpecialization bwd_spec) const
    {
        static_assert(std::is_base_of_v<ConvSpecializationBwdData_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.bwd_data_specialization = bwd_spec;
        return result;
    }
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
@@ -452,6 +506,15 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
    // Returns a copy of this algorithm with the GEMM padding flags set.
    // Only valid when the algorithm composition includes the GemmPad_ mixin.
    // NOTE(review): the parameter order here is (N, M), while the GemmPad_
    // member declaration and the rest of this API are M-first — callers
    // passing distinct values could silently swap the pads. Consider
    // reordering (would require updating all call sites).
    constexpr auto with_gemm_pad_params(size_t doPadGemmN_, size_t doPadGemmM_) const
    {
        static_assert(std::is_base_of_v<GemmPad_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.DoPadGemmN = doPadGemmN_;
        result.DoPadGemmM = doPadGemmM_;
        return result;
    }
|
||||
|
||||
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
|
||||
@@ -684,4 +747,35 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
|
||||
BlockGemm_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
// Bwd Data algorithm types.

// Backward-data XDL CShuffle: separate A/B K1 sizes, prefetch, transpose
// limits, and explicit GEMM-padding knobs.
using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle =
    ConvAlgorithmTemplate<ThreadBlock_,
                          BwdDataXdlGemm_,
                          Transfer_<>,
                          ConvSpecializationBwdData_,
                          MultipleDSpecialization_,
                          Prefetch_,
                          TransposeParams_,
                          GemmPad_>;

// Backward-data WMMA CShuffle (v1): single shared K1 with a grid-GEMM
// configuration; no transpose or padding knobs.
using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
    ConvAlgorithmTemplate<ThreadBlock_,
                          WmmaGemm_,
                          Transfer_<>,
                          ConvSpecializationBwdData_,
                          GridGemm_,
                          MultipleDSpecialization_,
                          Prefetch_>;

// Backward-data WMMA CShuffle V3: separate A/B K1 sizes with a block-GEMM
// pipeline, transpose limits, and explicit GEMM-padding knobs.
using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 =
    ConvAlgorithmTemplate<ThreadBlock_,
                          WmmaGemmABK1_,
                          Transfer_<>,
                          ConvSpecializationBwdData_,
                          BlockGemm_,
                          MultipleDSpecialization_,
                          Prefetch_,
                          TransposeParams_,
                          GemmPad_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -282,38 +282,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 2\n"
|
||||
" ├─ Memory access:\n"
|
||||
" │ ├─ A Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ B Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ └─ C Tile transfer: \n"
|
||||
" │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" │ └─ Vector access (GMEM write) instruction size: 2\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
// max_transpose_transfer_src_scalar_per_vector and max_transpose_transfer_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
|
||||
{
|
||||
using Instance =
|
||||
@@ -390,29 +391,29 @@ TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Memory access:\n"
|
||||
" │ ├─ A Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ ├─ B Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ └─ C Tile transfer: \n"
|
||||
" │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" │ └─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
|
||||
" ├─ Max Transpose dst scalar per vector: 1\n"
|
||||
@@ -494,33 +495,34 @@ TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Memory access:\n"
|
||||
" │ ├─ A Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ ├─ B Tile transfer: \n"
|
||||
" │ │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ └─ C Tile transfer: \n"
|
||||
" │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" │ └─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Num gemm k prefetch stage: 1\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
|
||||
@@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
// Gridwise bwd-data XDL GEMM parameter sets. All use ak1 = bk1 = 8 and the
// 32x32 XDL instruction; they differ only in how many XDL iterations each
// wave performs along M and N (encoded in the instance name as MxN).
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{
    .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};

constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{
    .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};

constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{
    .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};

constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{
    .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
// Gridwise bwd XDL GEMM parameter set: shared K1 = 8, 32x32 XDL instruction,
// 4x4 XDL iterations per wave.
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
    .k1 = 8,
    .xdl_params = {.m_per_xdl     = 32,
                   .n_per_xdl     = 32,
                   .m_xdl_per_wave = 4,
                   .n_xdl_per_wave = 4}};
|
||||
@@ -283,6 +303,13 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_wmma = 16,
|
||||
.n_per_wmma = 16,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1};
|
||||
|
||||
// Thread-block configuration: 256 threads working on a 256x256x32 (MxNxK) tile.
constexpr ThreadBlock ThreadBlock_256_256x256x32{
    .block_size = 256,
    .tile_size  = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
|
||||
@@ -85,6 +85,15 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseBwdDataXdlGemm>(GridwiseBwdDataXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
|
||||
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
@@ -112,6 +121,15 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseWmmaGemmABK1>(GridwiseWmmaGemmABK1 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_wmma << "," << t.n_per_wmma << ","
|
||||
<< t.m_wmma_per_wave << "," << t.n_wmma_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
|
||||
{
|
||||
@@ -283,12 +301,24 @@ inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BwdDataXdlGemm_>(BwdDataXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<WmmaGemmABK1_>(WmmaGemmABK1_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
inline std::string to_string(Transfer_<ThreadClusterRank> t)
|
||||
{
|
||||
@@ -311,6 +341,14 @@ inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwd
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecializationBwdData_>(ConvSpecializationBwdData_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_data_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
@@ -495,4 +533,36 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_X
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<BwdDataXdlGemm_>(t)) << ","
|
||||
<< to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<WmmaGemmABK1_>(t)) << ","
|
||||
<< to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user