mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
first instance of bwd data factory
This commit is contained in:
@@ -187,6 +187,14 @@ concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdDataXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> SizeType;
|
||||
{ t.bk1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
|
||||
@@ -199,6 +207,12 @@ concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdDataXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdDataXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise WMMA GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
|
||||
@@ -292,6 +306,11 @@ concept SpecifiesBwdWeightConvSpecialization = requires {
|
||||
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesBwdDataConvSpecialization = requires {
|
||||
{ T::bwd_data_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGemmSpecialization = requires {
|
||||
{ T::gemm_specialization } -> std::convertible_to<GemmSpecialization>;
|
||||
|
||||
@@ -29,23 +29,27 @@ concept FwdXdlAlgorithmBase =
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
(SpecifiesGridwiseBwdXdlGemm<T> || SpecifiesGridwiseBwdDataXdlGemm<T>) &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
|
||||
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
SpecifiesGridwiseWmmaGemm<T> &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>);
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> &&
|
||||
(SpecifiesBwdWeightConvSpecialization<T> || SpecifiesBwdDataConvSpecialization<T>) &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
// Reference algorithm concept
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_V1 instance
|
||||
// of a grouped bwd Data convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct ConvBwdDataMultiDXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdDataConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The backward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Types::OutDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::InDataType,
|
||||
typename Ops::OutElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
ALGORITHM.DoPadGemmM,
|
||||
ALGORITHM.DoPadGemmN,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
LOOP_SCHEDULER,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -77,6 +77,7 @@
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -151,10 +152,19 @@ constexpr auto make_conv_instance()
|
||||
// Backward data direction (will expand with more algorithms in the future)
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
static_assert(false,
|
||||
"Backward data convolution: Only reference and tile algorithms supported "
|
||||
"currently. "
|
||||
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
|
||||
if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdDataMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable backward data convolution kernel factory found for the provided "
|
||||
"ALGORITHM. "
|
||||
"The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, "
|
||||
"WMMA, DL (NHWC layout), or Large Tensor variant.");
|
||||
}
|
||||
}
|
||||
// Backward weight direction (will expand with more algorithms in the future)
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
@@ -180,4 +181,24 @@ SetBwdWeightConvSpecialization()
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
|
||||
SetBwdDataConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_data_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0:
|
||||
throw "FILTER_1x1_PAD0 is not supported for backward data convolution.";
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C:
|
||||
throw "FILTER ODD_C is not supported for backward data convolution.";
|
||||
case ConvSpecialization::FILTER_3x3:
|
||||
throw "FILTER_3x3 is not supported for backward data convolution.";
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -178,6 +178,7 @@ set(BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp
|
||||
)
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "gmock/gmock.h"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_DATA,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdDataGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_gemm_pad_params(0, 0)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdData_2DFp16_MultiD_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16>"}); // check compute types
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
@@ -54,6 +55,13 @@ struct GridwiseBwdXdlGemm
|
||||
};
|
||||
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdDataXdlGemm
|
||||
{
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
{
|
||||
@@ -209,6 +217,11 @@ struct BwdXdlGemm_
|
||||
GridwiseBwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BwdDataXdlGemm_
|
||||
{
|
||||
GridwiseBwdDataXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
{
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
@@ -231,12 +244,23 @@ struct ConvSpecializationBwdWeight_
|
||||
ConvSpecialization bwd_weight_specialization;
|
||||
};
|
||||
|
||||
struct ConvSpecializationBwdData_
|
||||
{
|
||||
ConvSpecialization bwd_data_specialization;
|
||||
};
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct GemmPad_
|
||||
{
|
||||
size_t DoPadGemmM;
|
||||
size_t DoPadGemmN;
|
||||
};
|
||||
|
||||
struct TransposeParams_
|
||||
{
|
||||
size_t max_transpose_transfer_src_scalar_per_vector{1};
|
||||
@@ -394,6 +418,10 @@ struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<BwdDataXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<WmmaGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
@@ -433,6 +461,14 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_bwd_data_specialization(ConvSpecialization bwd_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecializationBwdData_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.bwd_data_specialization = bwd_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
@@ -452,6 +488,15 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_gemm_pad_params(size_t doPadGemmN_, size_t doPadGemmM_) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmPad_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.DoPadGemmN = doPadGemmN_;
|
||||
result.DoPadGemmM = doPadGemmM_;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
|
||||
@@ -683,4 +728,14 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
|
||||
BlockGemm_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
// Bwd Data algorithm types
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdDataXdlGemm_,
|
||||
Transfer_<4>,
|
||||
ConvSpecializationBwdData_,
|
||||
MultipleDSpecialization_,
|
||||
Prefetch_,
|
||||
TransposeParams_,
|
||||
GemmPad_>;
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
@@ -85,6 +85,15 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseBwdDataXdlGemm>(GridwiseBwdDataXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
|
||||
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
@@ -283,6 +292,12 @@ inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BwdDataXdlGemm_>(BwdDataXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
{
|
||||
@@ -311,6 +326,14 @@ inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwd
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecializationBwdData_>(ConvSpecializationBwdData_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_data_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
@@ -495,4 +518,15 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_X
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<BwdDataXdlGemm_>(t)) << ","
|
||||
<< to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user