mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
This commit is contained in:
@@ -321,12 +321,29 @@ concept SpecifiesLargeTensorSupport = requires {
|
||||
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Satisfied when T exposes a `specialization` member that is convertible to
// ConvAlgorithmSpecialization and compares equal to
// ConvAlgorithmSpecialization::TWO_STAGE.
template <typename T>
|
||||
concept SpecifiesTwoStageSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
// Satisfied only when `T::specialization` is NOT a valid expression, i.e. the
// algorithm descriptor declares no specialization member at all.
template <typename T>
|
||||
concept SpecifiesGenericInstance = !requires {
|
||||
{ T::specialization };
|
||||
};
|
||||
|
||||
// Satisfied when T exposes both transpose-transfer members
// (max_transpose_transfer_src_scalar_per_vector and
// max_transpose_transfer_dst_scalar_per_vector), each satisfying SizeType.
template <typename T>
|
||||
concept SpecifiesTransposeTransfer = requires {
|
||||
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
|
||||
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
|
||||
// Satisfied when T exposes a `num_conv_groups_to_merge` member satisfying SizeType.
template <typename T>
|
||||
concept SpecifiesGemmBatchOptions = requires {
|
||||
{ T::num_conv_groups_to_merge } -> SizeType;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
|
||||
@@ -712,6 +712,36 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string
|
||||
return msg;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string {
|
||||
std::string msg;
|
||||
if constexpr (requires { T::specialization; }) {
|
||||
using SpecType = decltype(T::specialization);
|
||||
constexpr bool convertible = std::convertible_to<SpecType, ConvAlgorithmSpecialization>;
|
||||
msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) +
|
||||
(convertible ? "" : std::string(detail::get_type_info<SpecType>())) + "\n";
|
||||
|
||||
if constexpr (convertible) {
|
||||
constexpr bool is_two_stage = (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE);
|
||||
msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + "\n";
|
||||
}
|
||||
} else {
|
||||
msg += " → T::specialization: [✗] (missing member)\n";
|
||||
}
|
||||
|
||||
return msg;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string {
|
||||
std::string msg;
|
||||
if constexpr (requires { T::specialization; }) {
|
||||
msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n";
|
||||
msg += " → This concept requires the absence of the specialization member\n";
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string {
|
||||
std::string msg;
|
||||
|
||||
@@ -290,6 +290,7 @@ struct BwdXdlV3Algorithm {
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesGenericInstance)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
@@ -300,9 +301,10 @@ struct BwdXdlV3Algorithm {
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesBlockGemm;
|
||||
static constexpr bool c10 = c_SpecifiesGenericInstance;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
@@ -316,7 +318,58 @@ struct BwdXdlV3Algorithm {
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm);
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGenericInstance);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdTwoStageXdlAlgorithm {
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmBatchOptions)
|
||||
CHECK_CONCEPT(T, SpecifiesTwoStageSupport)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesBlockGemm;
|
||||
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
|
||||
static constexpr bool c11 = c_SpecifiesGemmBatchOptions;
|
||||
static constexpr bool c12 = c_SpecifiesTwoStageSupport;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && & c10 && c11 && c12;
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdXdlV3 Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTwoStageSupport);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -356,6 +409,12 @@ consteval int count_matches_bwd_xdl_v3() {
|
||||
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_bwd_two_stage_xdl() {
|
||||
using Alg = BwdTwoStageXdlAlgorithm<T>;
|
||||
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_large_tensor() {
|
||||
using Alg = LargeTensorAlgorithm<T>;
|
||||
@@ -417,8 +476,10 @@ consteval void diagnose_bwd_weight_algorithm_signature()
|
||||
{
|
||||
constexpr int xdl_matches = count_matches_bwd_xdl<AlgoType>();
|
||||
constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3<AlgoType>();
|
||||
constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl<AlgoType>();
|
||||
|
||||
constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches;
|
||||
constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches;
|
||||
constexpr int max_matches = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches;
|
||||
|
||||
if constexpr (max_matches == xdl_matches) {
|
||||
using Alg = BwdXdlAlgorithm<AlgoType>;
|
||||
@@ -428,6 +489,10 @@ consteval void diagnose_bwd_weight_algorithm_signature()
|
||||
using Alg = BwdXdlV3Algorithm<AlgoType>;
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
}
|
||||
else if constexpr (max_matches == two_stage_xdl_matches) {
|
||||
using Alg = BwdTwoStageXdlAlgorithm<AlgoType>;
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
}
|
||||
else {
|
||||
// This should never happen
|
||||
static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics.");
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance
|
||||
// of a grouped backward weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightTwoStageXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>, "Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>, "Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>, "Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>, "Invalid B source access order");
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
ALGORITHM.num_conv_groups_to_merge,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -49,12 +49,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Disable pragma message warnings for factory selection diagnostics
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-W#pragma-messages"
|
||||
#endif
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
@@ -71,6 +65,7 @@
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -103,34 +98,28 @@ constexpr auto make_conv_instance()
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(TileAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvTileFactory...")
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(FwdXdlV3Algorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdXdlV3Factory...")
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdXdlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdXdlFactory...")
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdWmmaAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdWmmaFactory...")
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdDlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdDlFactory...")
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(LargeTensorAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...")
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
@@ -148,14 +137,16 @@ constexpr auto make_conv_instance()
|
||||
{
|
||||
if constexpr (BwdXdlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...")
|
||||
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr (BwdXdlV3Algorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...")
|
||||
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr (BwdTwoStageXdlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
return typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
diagnose_bwd_weight_algorithm_signature<AlgoType>();
|
||||
@@ -171,8 +162,3 @@ constexpr auto make_conv_instance()
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
// Re-enable pragma message warnings
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
@@ -232,7 +232,8 @@ enum class PipelineScheduler
|
||||
|
||||
enum class ConvAlgorithmSpecialization
|
||||
{
|
||||
LARGE_TENSOR
|
||||
LARGE_TENSOR,
|
||||
TWO_STAGE
|
||||
};
|
||||
|
||||
// toString methods for enum classes
|
||||
|
||||
@@ -151,6 +151,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
// Signature under test: 2D backward-weight convolution, BF16 data with FP32
// accumulation, using GNHWC / GKYXC / GNHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
// Two-stage XDL CShuffle algorithm descriptor under test, built from the shared
// test presets; merges 2 conv groups and passes (2, 4) as the transpose
// parameters (source, destination order per with_transpose_params).
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_transpose_params(2, 4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
// Builds the two-stage backward-weight instance and hands run_test the list of
// strings expected for it: kernel name, the stringified algorithm parameters,
// specialization, layouts, elementwise ops, pipeline and trailing template args.
// NOTE(review): presumably run_test checks these against the instance's type
// string - confirm in utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave,v2", // pipeline versions
|
||||
"bf16,bf16,2,4>"}); // compute types and transpose params
|
||||
}
|
||||
@@ -23,7 +23,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
@@ -35,5 +36,6 @@ TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough"});
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16,2,2>"}); // check compute types and transpose params
|
||||
}
|
||||
|
||||
@@ -243,6 +243,11 @@ struct TransposeParams_
|
||||
size_t max_transpose_transfer_dst_scalar_per_vector{1};
|
||||
};
|
||||
|
||||
// Algorithm component carrying the number of convolution groups to merge
// (defaults to 1); set through with_num_conv_groups_to_merge().
struct GemmBatchOptions_
|
||||
{
|
||||
size_t num_conv_groups_to_merge{1};
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
@@ -280,6 +285,11 @@ struct DlTransfer_
|
||||
DlTransferABC transfer;
|
||||
};
|
||||
|
||||
// Algorithm component that tags a descriptor with the TWO_STAGE specialization
// via a static `specialization` member.
struct TwoStageSpecialization_
|
||||
{
|
||||
static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
template <typename BaseAlgorithm>
|
||||
struct LargeTensorWrapper
|
||||
@@ -433,8 +443,8 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_transpose_params(bool max_src_scalar_per_vector,
|
||||
bool max_dst_scalar_per_vector) const
|
||||
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
|
||||
size_t max_dst_scalar_per_vector) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
@@ -443,6 +453,14 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_conv_groups_to_merge = num_groups_to_merge;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
@@ -555,6 +573,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
|
||||
// Test descriptor for the two-stage XDL CShuffle backward-weight kernel: the V3
// component set plus transpose parameters, GEMM batch options and the
// TWO_STAGE specialization tag.
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_>;
|
||||
|
||||
|
||||
@@ -397,4 +397,14 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuff
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user