added bwd data wmma instances to builder

This commit is contained in:
Kevin Abraham
2026-02-12 18:51:37 +00:00
parent eefadf7afd
commit eb6e124e43
13 changed files with 657 additions and 40 deletions

View File

@@ -28,7 +28,8 @@ concept FwdXdlAlgorithmBase =
template <typename T>
concept BwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
(SpecifiesTileTransferParameters4D<T> || SpecifiesTileTransferParameters3D<T>) &&
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
@@ -111,6 +112,9 @@ concept BwdWmmaAlgorithm =
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept BwdMultiDWmmaAlgorithm = BwdWmmaAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
template <typename T>
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;

View File

@@ -0,0 +1,86 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for a DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3 instance
// of a grouped bwd-data convolution kernel.
//
// NOTE(review): this header includes device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp
// and passes V3-only parameters (split AK1/BK1 and the max-transpose-transfer scalars),
// so the instance must be the _V3 device op. The struct is named ...V3Factory to avoid
// an ODR collision with the non-V3 ConvBwdDataMultiDWmmaFactory declared in
// conv_bwd_data_multi_d_wmma_factory.hpp.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsBackwardData<SIGNATURE>
struct ConvBwdDataMultiDWmmaV3Factory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;

    // Helper bundles derived from the convolution signature.
    using Layouts       = internal::ConvTensorLayouts<SIGNATURE>;
    using Types         = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops           = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);

    // Tuning parameters derived from the algorithm descriptor.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdDataConvSpecialization<ALGORITHM>();
    static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
    static constexpr auto BLOCK          = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM  = ALGORITHM.gridwise_gemm;
    static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
        internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();

    // GEMM padding flags. The original spelled these as template *declarations*
    // ("bool DoPadGemmM") inside the argument list, which is ill-formed; they must
    // be constant values.
    // TODO(review): confirm the algorithm-descriptor field names for these flags.
    static constexpr bool DO_PAD_GEMM_M = ALGORITHM.do_pad_gemm_m;
    static constexpr bool DO_PAD_GEMM_N = ALGORITHM.do_pad_gemm_n;

    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();

    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);

    // The backward-data convolution kernel class instance (V3 pipeline).
    using Instance =
        ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle_V3<
            SPATIAL_DIM,
            typename Layouts::OutLayout,
            typename Layouts::WeiLayout,
            typename Layouts::DsLayout,
            typename Layouts::InLayout,
            typename Types::OutDataType,
            typename Types::WeiDataType,
            typename Types::AccDataType,
            typename Types::OutComputeType,
            typename Types::DsDataType,
            typename Types::InDataType,
            typename Ops::OutElementwiseOp,
            typename Ops::WeiElementwiseOp,
            typename Ops::InElementwiseOp,
            BWD_CONV_SPECIALIZATION,
            DO_PAD_GEMM_M,
            DO_PAD_GEMM_N,
            BLOCK.block_size,
            BLOCK.per_block.m,
            BLOCK.per_block.n,
            BLOCK.per_block.k,
            GRIDWISE_GEMM.ak1,
            GRIDWISE_GEMM.bk1,
            GRIDWISE_GEMM.m_per_wmma,
            GRIDWISE_GEMM.n_per_wmma,
            GRIDWISE_GEMM.m_wmma_per_wave,
            GRIDWISE_GEMM.n_wmma_per_wave,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
            A_BLOCK_TRANSFER.src_vector_dim,
            A_BLOCK_TRANSFER.src_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_padding,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
            B_BLOCK_TRANSFER.src_vector_dim,
            B_BLOCK_TRANSFER.src_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_padding,
            C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
            C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
            to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
            C_BLOCK_TRANSFER.scalar_per_vector,
            ALGORITHM.num_gemm_k_prefetch_stages,
            LOOP_SCHEDULER,
            GRIDWISE_GEMM_PIPELINE_VERSION,
            ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
            ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,107 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Builds the DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle device-op type for a
// grouped backward-data convolution from a compile-time signature/algorithm pair.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsBackwardData<SIGNATURE>
struct ConvBwdDataMultiDWmmaFactory
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;

    // Helper bundles derived from the convolution signature.
    using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
    using Types   = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops     = internal::ConvElementwiseOps<SIGNATURE>;

    using AlgorithmType = decltype(ALGORITHM);

    // Tuning parameters derived from the algorithm descriptor.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdDataConvSpecialization<ALGORITHM>();
    static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
    static constexpr auto BLOCK          = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM  = ALGORITHM.gridwise_gemm;
    static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
        internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();

    // Per-tensor block-transfer descriptors.
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();

    // Compile-time sanity checks on the tuning parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);

    // The backward-data convolution device-op instance. Arguments follow the
    // device op's template-parameter declaration order.
    using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
        SPATIAL_DIM,
        typename Layouts::OutLayout, typename Layouts::WeiLayout,
        typename Layouts::DsLayout, typename Layouts::InLayout,
        typename Types::OutDataType, typename Types::WeiDataType,
        typename Types::AccDataType, typename Types::OutComputeType,
        typename Types::DsDataType, typename Types::InDataType,
        typename Ops::OutElementwiseOp, typename Ops::WeiElementwiseOp,
        typename Ops::InElementwiseOp,
        BWD_CONV_SPECIALIZATION,
        BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k,
        GRIDWISE_GEMM.k1,
        GRIDWISE_GEMM.m_per_wmma, GRIDWISE_GEMM.n_per_wmma,
        GRIDWISE_GEMM.m_wmma_per_wave, GRIDWISE_GEMM.n_wmma_per_wave,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
        A_BLOCK_TRANSFER.src_vector_dim,
        A_BLOCK_TRANSFER.src_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_padding,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
        B_BLOCK_TRANSFER.src_vector_dim,
        B_BLOCK_TRANSFER.src_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_padding,
        C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
        C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
        to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
        C_BLOCK_TRANSFER.scalar_per_vector,
        ALGORITHM.num_gemm_k_prefetch_stages,
        LOOP_SCHEDULER,
        GRIDWISE_GEMM_PIPELINE_VERSION>;
};
} // namespace ck_tile::builder::factory

View File

@@ -78,6 +78,7 @@
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_wmma_factory.hpp"
namespace ck_tile::builder::factory {
@@ -152,19 +153,23 @@ constexpr auto make_conv_instance()
// Backward data direction (will expand with more algorithms in the future)
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
// if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
// {
return typename ConvBwdDataMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
// }
// else
// {
// static_assert(
// false,
// "No suitable backward data convolution kernel factory found for the provided "
// "ALGORITHM. "
// "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3,
// XDL, " "WMMA, DL (NHWC layout), or Large Tensor variant.");
// }
if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
{
return typename ConvBwdDataMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDWmmaAlgorithm<AlgoType>)
{
return typename ConvBwdDataMultiDWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else
{
static_assert(
false,
"No suitable backward data convolution kernel factory found for the provided "
"ALGORITHM. "
"The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, "
"WMMA, DL (NHWC layout), or Large Tensor variant.");
}
}
// Backward weight direction (will expand with more algorithms in the future)
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)

View File

@@ -0,0 +1,319 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
namespace ck::tensor_operation::device {

// Forward declaration of the CK device op so the InstanceTraits specialization
// below can be written without including the full device-op header. The
// parameter list here must stay in sync with the definition in
// device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp.
template <index_t NDimSpatial,
          typename OutLayout, // output image
          typename WeiLayout, // weight
          typename DsLayout,  // bias
          typename InLayout,  // input image
          typename OutDataType, // output image
          typename WeiDataType, // weight
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType, // bias
          typename InDataType, // input image
          typename OutElementwiseOp, // output image
          typename WeiElementwiseOp, // weight
          typename InElementwiseOp,  // C, bias, and input image
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage,
          ck::LoopScheduler LoopSched,
          ck::PipelineVersion PipelineVer>
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle;

} // namespace ck::tensor_operation::device
namespace ck_tile {
namespace reflect {
/// @brief Tag type identifying the DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
/// device-kernel family; exposed as InstanceTraits::device_kernel_tag below.
struct DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag
{
};
// InstanceTraits specialization for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle.
// Re-exposes every template parameter as a named member and can render the
// canonical instance string.
template <index_t NDimSpatial,
          typename OutLayout_, // output image
          typename WeiLayout_, // weight
          typename DsLayout_,  // bias
          typename InLayout_,  // input image
          typename OutDataType_, // output image
          typename WeiDataType_, // weight
          typename AccDataType_,
          typename CShuffleDataType_,
          typename DsDataType_, // bias
          typename InDataType_, // input image
          typename OutElementwiseOp_, // output image
          typename WeiElementwiseOp_, // weight
          typename InElementwiseOp_,  // C, bias, and input image
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBackwardDataSpecialization,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1_,
          typename ABlockTransferThreadClusterArrangeOrder_,
          typename ABlockTransferSrcAccessOrder_,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1_,
          typename BBlockTransferThreadClusterArrangeOrder_,
          typename BBlockTransferSrcAccessOrder_,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage,
          ck::LoopScheduler LoopSched,
          ck::PipelineVersion PipelineVer>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<
    NDimSpatial,
    OutLayout_,  // output image
    WeiLayout_,  // weight
    DsLayout_,   // bias
    InLayout_,   // input image
    OutDataType_, // output image
    WeiDataType_, // weight
    AccDataType_,
    CShuffleDataType_,
    DsDataType_, // bias
    InDataType_, // input image
    OutElementwiseOp_, // output image
    WeiElementwiseOp_, // weight
    InElementwiseOp_,  // C, bias, and input image
    ConvBackwardDataSpecialization,
    BlockSize,
    MPerBlock,
    NPerBlock,
    K0PerBlock,
    K1,
    MPerWMMA,
    NPerWMMA,
    MRepeat,
    NRepeat,
    ABlockTransferThreadClusterLengths_AK0_M_AK1_,
    ABlockTransferThreadClusterArrangeOrder_,
    ABlockTransferSrcAccessOrder_,
    ABlockTransferSrcVectorDim,
    ABlockTransferSrcScalarPerVector,
    ABlockTransferDstScalarPerVector_AK1,
    ABlockLdsExtraM,
    BBlockTransferThreadClusterLengths_BK0_N_BK1_,
    BBlockTransferThreadClusterArrangeOrder_,
    BBlockTransferSrcAccessOrder_,
    BBlockTransferSrcVectorDim,
    BBlockTransferSrcScalarPerVector,
    BBlockTransferDstScalarPerVector_BK1,
    BBlockLdsExtraN,
    CShuffleMRepeatPerShuffle,
    CShuffleNRepeatPerShuffle,
    CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
    NumGemmKPrefetchStage,
    LoopSched,
    PipelineVer>>
{
    static constexpr auto kTensorOpName = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";

    /// @brief Tag type identifying this device kernel variant
    using device_kernel_tag = DeviceGroupedConvBwdData_multiple_d_Wmma_CShuffle_Tag;

    static constexpr ck::index_t kSpatialDim = NDimSpatial;

    // Tensor layouts.
    using InLayout  = InLayout_;
    using WeiLayout = WeiLayout_;
    using OutLayout = OutLayout_;
    using DsLayout  = DsLayout_;

    // Tensor data types.
    using InDataType  = InDataType_;
    using WeiDataType = WeiDataType_;
    using OutDataType = OutDataType_;
    using AccDataType = AccDataType_;
    using DsDataType  = DsDataType_;

    // Elementwise operations.
    using InElementwiseOperation  = InElementwiseOp_;
    using WeiElementwiseOperation = WeiElementwiseOp_;
    using OutElementwiseOperation = OutElementwiseOp_;

    static constexpr auto kConvBwdDataSpecialization = ConvBackwardDataSpecialization;

    // Thread-block and WMMA tiling parameters.
    static constexpr ck::index_t kBlockSize  = BlockSize;
    static constexpr ck::index_t kMPerBlock  = MPerBlock;
    static constexpr ck::index_t kNPerBlock  = NPerBlock;
    static constexpr ck::index_t kK0PerBlock = K0PerBlock;
    static constexpr ck::index_t kK1         = K1;
    static constexpr ck::index_t kMPerWmma   = MPerWMMA;
    static constexpr ck::index_t kNPerWmma   = NPerWMMA;
    static constexpr ck::index_t kMRepeat    = MRepeat;
    static constexpr ck::index_t kNRepeat    = NRepeat;

    // C-shuffle epilogue parameters.
    static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
    static constexpr ck::index_t kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
    static constexpr ck::index_t kCDEShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVector_NPerBlock;
    static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
    static constexpr int kNumGemmKPrefetchStage       = NumGemmKPrefetchStage;

    using ABlockTransferThreadClusterLengths_AK0_M_AK1 =
        ABlockTransferThreadClusterLengths_AK0_M_AK1_;
    using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
    using ABlockTransferSrcAccessOrder            = ABlockTransferSrcAccessOrder_;

    // A block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kAThreadClusterLengths =
        detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
    static constexpr auto kAThreadClusterArrangeOrder =
        detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kABlockTransferSrcAccessOrder =
        detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
    static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
        ABlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
        ABlockTransferDstScalarPerVector_AK1;
    static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;

    using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
        BBlockTransferThreadClusterLengths_BK0_N_BK1_;
    using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
    using BBlockTransferSrcAccessOrder            = BBlockTransferSrcAccessOrder_;

    // B block transfer thread cluster dimensions (converted to std::array)
    static constexpr auto kBThreadClusterLengths =
        detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
    static constexpr auto kBThreadClusterArrangeOrder =
        detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
    static constexpr auto kBBlockTransferSrcAccessOrder =
        detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
    static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
    static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
        BBlockTransferSrcScalarPerVector;
    static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
        BBlockTransferDstScalarPerVector_BK1;
    static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;

    using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
    static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
    static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;

    // Static member function to generate instance string.
    // The parameters are emitted in the device op's template-parameter order.
    static std::string instance_string()
    {
        std::ostringstream oss;
        // Kernel type name
        oss << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
        // Template parameters in exact order
        oss << "<" << kSpatialDim;                      // 1. NDimSpatial
        oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
        oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
        oss << "," << detail::tuple_name<DsLayout>();   // 4. DsLayout
        oss << "," << detail::layout_name<InLayout>();  // 5. InLayout
        oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
        oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
        oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
        oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
        oss << "," << detail::type_name<InDataType>();  // 10. InDataType
        oss << ","
            << detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
                                                                       // OutElementwiseOperation
        oss << ","
            << detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
                                                                       // WeiElementwiseOperation
        oss << ","
            << detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
        oss << ","
            << detail::conv_bwd_data_spec_name(
                   kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
        oss << "," << kBlockSize;               // 15. BlockSize
        oss << "," << kMPerBlock;               // 16. MPerBlock
        oss << "," << kNPerBlock;               // 17. NPerBlock
        oss << "," << kK0PerBlock;              // 18. K0PerBlock
        oss << "," << kK1;                      // 19. K1
        oss << "," << kMPerWmma;                // 20. MPerWmma
        oss << "," << kNPerWmma;                // 21. NPerWmma
        oss << "," << kMRepeat;                 // 22. MRepeat
        oss << "," << kNRepeat;                 // 23. NRepeat
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
        oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>();      // 25.
        oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>();                 // 26.
        oss << "," << kABlockTransferSrcVectorDim;                                           // 27.
        oss << "," << kABlockTransferSrcScalarPerVector;                                     // 28.
        oss << "," << kABlockTransferDstScalarPerVectorK1;                                   // 29.
        oss << "," << (kABlockLdsExtraM ? "true" : "false");                                 // 30.
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
        oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>();      // 32.
        oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>();                 // 33.
        oss << "," << kBBlockTransferSrcVectorDim;                                           // 34.
        oss << "," << kBBlockTransferSrcScalarPerVector;                                     // 35.
        oss << "," << kBBlockTransferDstScalarPerVectorK1;                                   // 36.
        oss << "," << (kBBlockLdsExtraN ? "true" : "false");                                 // 37.
        oss << "," << kCShuffleMRepeatPerShuffle;                                            // 38.
        oss << "," << kCShuffleNRepeatPerShuffle;                                            // 39.
        oss << ","
            << detail::sequence_name<
                   CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
        oss << "," << kCDEShuffleBlockTransferScalarPerVector_NPerBlock; // 41.
        // NOTE: the remaining three parameters are printed in the declaration order
        // of the device op (NumGemmKPrefetchStage, LoopSched, PipelineVer); the
        // original printed LoopSched first, contradicting the "exact order" claim.
        oss << "," << kNumGemmKPrefetchStage;                      // 42. NumGemmKPrefetchStage
        oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
        oss << "," << detail::pipeline_version_name(kPipelineVer); // 44. PipelineVer
        oss << ">";
        return oss.str();
    }
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -279,36 +279,39 @@ struct InstanceTraits<
// Template parameters in exact order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<OutLayout>(); // 2. OutLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::type_name<InDataType>(); // 6. InDataType
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<InLayout>(); // 5. InLayout
oss << "," << detail::type_name<OutDataType>(); // 6. OutDataType
oss << "," << detail::type_name<WeiDataType>(); // 7. WeiDataType
oss << "," << detail::type_name<OutDataType>(); // 8. OutDataType
oss << "," << detail::type_name<AccDataType>(); // 9. AccDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::tuple_name<DsDataType>(); // 9. DsDataType
oss << "," << detail::type_name<InDataType>(); // 10. InDataType
oss << ","
<< detail::elementwise_op_name<InElementwiseOperation>(); // 11. InElementwiseOperation
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
// OutElementwiseOperation
oss << ","
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 12.
// WeiElementwiseOperation
oss << ","
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 13.
// OutElementwiseOperation
<< detail::elementwise_op_name<InElementwiseOperation>(); // 13. InElementwiseOperation
oss << ","
<< detail::conv_bwd_data_spec_name(
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kAK1; // 19. AK1
oss << "," << kBK1; // 19,5. BK1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << kDoPadGemmM;
oss << "," << kDoPadGemmN;
oss << "," << kNumGemmKPrefetchStage;
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kAK1; // 19. AK1
oss << "," << kBK1; // 19,5. BK1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
@@ -328,10 +331,13 @@ struct InstanceTraits<
oss << ","
<< detail::sequence_name<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 41.
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 44. LoopSched
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42.
oss << "," << kNumGemmKPrefetchStage; // 41.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
oss << "," << detail::type_name<ComputeTypeA>(); // 44.
oss << "," << detail::type_name<ComputeTypeB>(); // 45.
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 46.
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 47.
oss << ">";

View File

@@ -18,6 +18,7 @@
#include <string_view>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

View File

@@ -195,6 +195,7 @@ target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
# Backward-data instance-builder tests; the multi-D WMMA CShuffle test is built
# alongside the existing XDL and tile-based variants.
add_ck_builder_test(test_ckb_build_bwd_data_instances
    conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
    conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp
    conv/ck/test_ckb_conv_bwd_data_multi_d_wmma_cshuffle.cpp
)
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gmock/gmock.h"
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;

// 2D FP16 backward-data convolution signature with grouped layouts
// (GNHWC input, GKYXC weight, GNHWK output) and FP32 accumulation.
constexpr auto SIGNATURE =
    ckt::ConvSignature{.spatial_dim = 2,
                       .direction = ckb::ConvDirection::BACKWARD_DATA,
                       .data_type = ckb::DataType::FP16,
                       .accumulation_data_type = ckb::DataType::FP32,
                       .input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
                       .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
                       .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};

// WMMA algorithm descriptor: 64-thread block with a 32x32x32 tile, 16x16 WMMA
// gemm with a 2x1 per-wave repeat, default bwd-data specialization, one
// prefetch stage, and the V1 gridwise-gemm pipeline.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle{}
                               .with_thread_block(cku::ThreadBlock_64_32x32x32)
                               .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
                               .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
                               .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
                               .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
                               .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);

using Builder  = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

// Builds the WMMA bwd-data instance and verifies its instance string contains
// the kernel name, the serialized tuning parameters, and the expected layouts,
// elementwise ops, and data types.
TEST(BwdData_2DFp16_MultiD_Wmma_CShuffle_GNHWC, Create)
{
    const auto expected_transfer_parameters = to_string(ALGORITHM);
    std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
    cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle",
                            expected_transfer_parameters,
                            "Default",
                            "GNHWK,GKYXC,EmptyTuple,GNHWC",
                            "PassThrough,PassThrough,PassThrough",
                            "fp16,fp16"}); // check compute types
}

View File

@@ -31,14 +31,14 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdData_2DFp16_MultiD_CShuffle_GNHWC, Create)
TEST(BwdData_2DFp16_MultiD_Xdl_CShuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"GNHWK,GKYXC,EmptyTuple,GNHWC",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16"}); // check compute types
}

View File

@@ -738,4 +738,15 @@ using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle =
Prefetch_,
TransposeParams_,
GemmPad_>;
// Bwd Data algorithm types
// Test algorithm descriptor for the grouped bwd-data WMMA CShuffle factory:
// thread block, WMMA gemm shape, A/B/C transfer params, bwd-data conv
// specialization, gridwise-gemm pipeline, multiple-D support, and prefetch.
using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
    ConvAlgorithmTemplate<ThreadBlock_,
                          WmmaGemm_,
                          Transfer_<>,
                          ConvSpecializationBwdData_,
                          GridGemm_,
                          MultipleDSpecialization_,
                          Prefetch_>;
} // namespace ck_tile::builder::test

View File

@@ -529,4 +529,15 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl
return oss.str();
}
/// @brief Serializes the WMMA bwd-data algorithm descriptor's tuning parameters
/// (thread block, WMMA gemm shape, transfer config) into a comma-separated string
/// matching the order they appear in the generated instance string.
/// @param t the algorithm descriptor to serialize (sliced to each base in turn)
/// @return comma-separated parameter string
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle>(
    ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle t)
{
    std::ostringstream oss;
    oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
        << "," << to_string(static_cast<Transfer_<>>(t));
    // The original contained a duplicated, unreachable second `return` here.
    return oss.str();
}
} // namespace ck_tile::builder::test

View File

@@ -20,6 +20,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -826,6 +831,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
// Returns the canonical template-parameter string for this device-op instance,
// rendered by the ck_tile::reflect::InstanceTraits specialization; the
// static_assert gives a readable error if that specialization is missing.
std::string GetInstanceString() const override
{
    static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp "
                  "for the given template parameters.");
    return ck_tile::reflect::instance_string<DeviceOp>();
}
// Wraps the instance string in a reflect::Description for the builder's
// introspection API.
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
    return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device