Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

This commit is contained in:
Ville Pietilä
2025-12-31 04:32:28 -05:00
parent 30c10e2544
commit 75710202ab
11 changed files with 310 additions and 27 deletions

View File

@@ -321,12 +321,29 @@ concept SpecifiesLargeTensorSupport = requires {
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
};
template <typename T>
concept SpecifiesTwoStageSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
};
template <typename T>
concept SpecifiesGenericInstance = !requires {
{ T::specialization };
};
template <typename T>
concept SpecifiesTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
};
template <typename T>
concept SpecifiesGemmBatchOptions = requires {
{ T::num_conv_groups_to_merge } -> SizeType;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */

View File

@@ -712,6 +712,36 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string
return msg;
}
template <typename T>
consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string {
std::string msg;
if constexpr (requires { T::specialization; }) {
using SpecType = decltype(T::specialization);
constexpr bool convertible = std::convertible_to<SpecType, ConvAlgorithmSpecialization>;
msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) +
(convertible ? "" : std::string(detail::get_type_info<SpecType>())) + "\n";
if constexpr (convertible) {
constexpr bool is_two_stage = (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE);
msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + "\n";
}
} else {
msg += " → T::specialization: [✗] (missing member)\n";
}
return msg;
}
template <typename T>
consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string {
std::string msg;
if constexpr (requires { T::specialization; }) {
msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n";
msg += " → This concept requires the absence of the specialization member\n";
}
return msg;
}
template <typename T>
consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string {
std::string msg;

View File

@@ -290,6 +290,7 @@ struct BwdXdlV3Algorithm {
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
CHECK_CONCEPT(T, SpecifiesBlockGemm)
CHECK_CONCEPT(T, SpecifiesGenericInstance)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
@@ -300,9 +301,10 @@ struct BwdXdlV3Algorithm {
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_SpecifiesBlockGemm;
static constexpr bool c10 = c_SpecifiesGenericInstance;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
}
static consteval auto message() -> std::string {
@@ -316,7 +318,58 @@ struct BwdXdlV3Algorithm {
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
DIAGNOSTIC_LINE(SpecifiesBlockGemm);
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
DIAGNOSTIC_LINE(SpecifiesGenericInstance);
}
};
template <typename T>
struct BwdTwoStageXdlAlgorithm {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
CHECK_CONCEPT(T, SpecifiesBlockGemm)
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
CHECK_CONCEPT(T, SpecifiesGemmBatchOptions)
CHECK_CONCEPT(T, SpecifiesTwoStageSupport)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
static constexpr bool c3 = c_SpecifiesBlockTransfer;
static constexpr bool c4 = c_SpecifiesLdsTransfer;
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_SpecifiesBlockGemm;
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
static constexpr bool c11 = c_SpecifiesGemmBatchOptions;
static constexpr bool c12 = c_SpecifiesTwoStageSupport;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && & c10 && c11 && c12;
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdXdlV3 Algorithm:\n") +
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) +
DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) +
DIAGNOSTIC_LINE(SpecifiesTwoStageSupport);
}
};
@@ -356,6 +409,12 @@ consteval int count_matches_bwd_xdl_v3() {
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9;
}
template <typename T>
consteval int count_matches_bwd_two_stage_xdl() {
using Alg = BwdTwoStageXdlAlgorithm<T>;
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12;
}
template <typename T>
consteval int count_matches_large_tensor() {
using Alg = LargeTensorAlgorithm<T>;
@@ -417,8 +476,10 @@ consteval void diagnose_bwd_weight_algorithm_signature()
{
constexpr int xdl_matches = count_matches_bwd_xdl<AlgoType>();
constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3<AlgoType>();
constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl<AlgoType>();
constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches;
constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches;
constexpr int max_matches = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches;
if constexpr (max_matches == xdl_matches) {
using Alg = BwdXdlAlgorithm<AlgoType>;
@@ -428,6 +489,10 @@ consteval void diagnose_bwd_weight_algorithm_signature()
using Alg = BwdXdlV3Algorithm<AlgoType>;
static_assert(Alg::is_valid(), Alg::message());
}
else if constexpr (max_matches == two_stage_xdl_matches) {
using Alg = BwdTwoStageXdlAlgorithm<AlgoType>;
static_assert(Alg::is_valid(), Alg::message());
}
else {
// This should never happen
static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics.");

View File

@@ -0,0 +1,106 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightTwoStageXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>, "Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>, "Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>, "Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>, "Invalid B source access order");
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
ALGORITHM.num_conv_groups_to_merge,
typename Types::InComputeType,
typename Types::WeiComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -49,12 +49,6 @@
#pragma once
// Disable pragma message warnings for factory selection diagnostics
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-W#pragma-messages"
#endif
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
@@ -71,6 +65,7 @@
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
namespace ck_tile::builder::factory {
@@ -103,34 +98,28 @@ constexpr auto make_conv_instance()
// CK Tile supports common factory for each direction
if constexpr(TileAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvTileFactory...")
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(FwdXdlV3Algorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvFwdXdlV3Factory...")
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(FwdXdlAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvFwdXdlFactory...")
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(FwdWmmaAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvFwdWmmaFactory...")
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(FwdDlAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvFwdDlFactory...")
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(LargeTensorAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...")
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else
@@ -148,14 +137,16 @@ constexpr auto make_conv_instance()
{
if constexpr (BwdXdlAlgorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...")
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr (BwdXdlV3Algorithm<AlgoType>::is_valid())
{
#pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...")
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr (BwdTwoStageXdlAlgorithm<AlgoType>::is_valid())
{
return typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else
{
diagnose_bwd_weight_algorithm_signature<AlgoType>();
@@ -171,8 +162,3 @@ constexpr auto make_conv_instance()
}
} // namespace ck_tile::builder::factory
// Re-enable pragma message warnings
#ifdef __clang__
#pragma clang diagnostic pop
#endif

View File

@@ -232,7 +232,8 @@ enum class PipelineScheduler
enum class ConvAlgorithmSpecialization
{
LARGE_TENSOR
LARGE_TENSOR,
TWO_STAGE
};
// toString methods for enum classes

View File

@@ -151,6 +151,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
add_ck_builder_test(test_ckb_build_bwd_weight_instances
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
)

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(2)
.with_transpose_params(2, 4);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"Intrawave,v2", // pipeline versions
"bf16,bf16,2,4>"}); // compute types and transpose params
}

View File

@@ -23,7 +23,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_transpose_params(2, 2);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
@@ -35,5 +36,6 @@ TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough"});
"PassThrough,PassThrough,PassThrough",
"fp16,fp16,2,2>"}); // check compute types and transpose params
}

View File

@@ -243,6 +243,11 @@ struct TransposeParams_
size_t max_transpose_transfer_dst_scalar_per_vector{1};
};
struct GemmBatchOptions_
{
size_t num_conv_groups_to_merge{1};
};
struct BlockGemm_
{
BlockGemm block_gemm;
@@ -280,6 +285,11 @@ struct DlTransfer_
DlTransferABC transfer;
};
struct TwoStageSpecialization_
{
static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE;
};
// Specialization wrapper for large tensor support
template <typename BaseAlgorithm>
struct LargeTensorWrapper
@@ -433,8 +443,8 @@ struct ConvAlgorithmTemplate : Components...
return result;
}
constexpr auto with_transpose_params(bool max_src_scalar_per_vector,
bool max_dst_scalar_per_vector) const
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
size_t max_dst_scalar_per_vector) const
{
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
auto result = *this;
@@ -443,6 +453,14 @@ struct ConvAlgorithmTemplate : Components...
return result;
}
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
{
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
auto result = *this;
result.num_conv_groups_to_merge = num_groups_to_merge;
return result;
}
template <typename BG>
constexpr auto with_block_gemm(const BG& bg) const
{
@@ -555,6 +573,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_>;

View File

@@ -397,4 +397,14 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuff
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
} // namespace ck_tile::builder::test