Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.

This commit is contained in:
Ville Pietilä
2026-01-02 09:25:38 -05:00
parent 1759db7250
commit c3a9044bad
6 changed files with 156 additions and 54 deletions

View File

@@ -281,7 +281,7 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase<T>
};
template <typename T>
struct BwdXdlAlgorithm {
struct BwdXdlAlgorithmBase {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransfer4D)
@@ -290,8 +290,6 @@ struct BwdXdlAlgorithm {
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided)
CHECK_CONCEPT(T, SpecifiesGenericInstance)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
@@ -301,16 +299,13 @@ struct BwdXdlAlgorithm {
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_TransposeTransferWellDefinedIfProvided;
static constexpr bool c10 = c_SpecifiesGenericInstance;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8;
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdXdl Algorithm:\n") +
return
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) +
@@ -318,49 +313,45 @@ struct BwdXdlAlgorithm {
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization);
}
};
template <typename T>
struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase<T>{
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
CHECK_CONCEPT(T, SpecifiesGenericInstance)
static constexpr bool c9 = c_SpecifiesTransposeTransfer;
static constexpr bool c10 = c_SpecifiesGenericInstance;
static consteval bool is_valid() {
return c9 && c10 && BwdXdlAlgorithmBase<T>::is_valid();
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdXdl Algorithm:\n") +
BwdXdlAlgorithmBase<T>::message() +
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) +
DIAGNOSTIC_LINE(SpecifiesGenericInstance);
}
};
template <typename T>
struct BwdMultiDXdlAlgorithm {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransfer4D)
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase<T>{
CHECK_CONCEPT(T, SpecifiesMultipleDSupport)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
static constexpr bool c3 = c_SpecifiesBlockTransfer4D;
static constexpr bool c4 = c_SpecifiesLdsTransfer;
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_SpecifiesMultipleDSupport;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
return c9 && BwdXdlAlgorithmBase<T>::is_valid();
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdXdl Algorithm:\n") +
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) +
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
BwdXdlAlgorithmBase<T>::message() +
DIAGNOSTIC_LINE(SpecifiesMultipleDSupport);
}
};
@@ -448,7 +439,7 @@ struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase<T>{
};
template <typename T>
struct BwdWmmaAlgorithm {
struct BwdWmmaAlgorithmBase {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
@@ -457,10 +448,6 @@ struct BwdWmmaAlgorithm {
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
CHECK_CONCEPT(T, SpecifiesNumPrefetchStages)
CHECK_CONCEPT(T, SpecifiesLoopScheduler)
CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline)
CHECK_CONCEPT(T, SpecifiesGenericInstance)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
@@ -470,18 +457,13 @@ struct BwdWmmaAlgorithm {
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_SpecifiesNumPrefetchStages;
static constexpr bool c10 = c_SpecifiesLoopScheduler;
static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline;
static constexpr bool c12 = c_SpecifiesGenericInstance;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12;
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8;
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdWmma Algorithm:\n") +
return
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
@@ -489,7 +471,30 @@ struct BwdWmmaAlgorithm {
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization);
}
};
template <typename T>
struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase<T> {
CHECK_CONCEPT(T, SpecifiesNumPrefetchStages)
CHECK_CONCEPT(T, SpecifiesLoopScheduler)
CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline)
CHECK_CONCEPT(T, SpecifiesGenericInstance)
static constexpr bool c9 = c_SpecifiesNumPrefetchStages;
static constexpr bool c10 = c_SpecifiesLoopScheduler;
static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline;
static constexpr bool c12 = c_SpecifiesGenericInstance;
static consteval bool is_valid() {
return c9 && c10 && c11 && c12 && BwdWmmaAlgorithmBase<T>::is_valid();
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdWmma Algorithm:\n") +
BwdWmmaAlgorithmBase<T>::message() +
DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) +
DIAGNOSTIC_LINE(SpecifiesLoopScheduler) +
DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline) +
@@ -497,6 +502,27 @@ struct BwdWmmaAlgorithm {
}
};
template <typename T>
struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase<T> {
CHECK_CONCEPT(T, SpecifiesBlockGemm)
CHECK_CONCEPT(T, SpecifiesMultipleDSupport)
static constexpr bool c9 = c_SpecifiesBlockGemm;
static constexpr bool c10 = c_SpecifiesMultipleDSupport;
static consteval bool is_valid() {
return c9 && c10 && BwdWmmaAlgorithmBase<T>::is_valid();
}
static consteval auto message() -> std::string {
return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n"
"Concepts for BwdMultiDWmma Algorithm:\n") +
BwdWmmaAlgorithmBase<T>::message() +
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
DIAGNOSTIC_LINE(SpecifiesMultipleDSupport);
}
};
template <typename T>
struct BwdWmmaV3AlgorithmBase {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
@@ -508,7 +534,7 @@ struct BwdWmmaV3AlgorithmBase {
CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm)
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
CHECK_CONCEPT(T, SpecifiesBlockGemm)
CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided)
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
@@ -519,7 +545,7 @@ struct BwdWmmaV3AlgorithmBase {
static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm;
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
static constexpr bool c9 = c_SpecifiesBlockGemm;
static constexpr bool c10 = c_TransposeTransferWellDefinedIfProvided;
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
static consteval bool is_valid() {
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
@@ -536,7 +562,7 @@ struct BwdWmmaV3AlgorithmBase {
DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) +
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided);
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer);
}
};
@@ -670,6 +696,12 @@ consteval int count_matches_bwd_wmma() {
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12;
}
template <typename T>
consteval int count_matches_bwd_multi_d_wmma() {
using Alg = BwdMultiDWmmaAlgorithm<T>;
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12;
}
template <typename T>
consteval int count_matches_bwd_wmma_v3() {
using Alg = BwdWmmaV3Algorithm<T>;
@@ -785,6 +817,7 @@ consteval void diagnose_bwd_weight_algorithm_signature()
constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3<AlgoType>();
constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3<AlgoType>();
constexpr int wmma_matches = count_matches_bwd_wmma<AlgoType>();
constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma<AlgoType>();
// Check whether we have XDL or WMMA algorithm
if constexpr (SpecifiesGridwiseBwdXdlGemm<AlgoType>)
@@ -818,7 +851,9 @@ consteval void diagnose_bwd_weight_algorithm_signature()
else if constexpr (SpecifiesGridwiseWmmaGemm<AlgoType>)
{
constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches;
constexpr int max_matches = max_1 > wmma_matches ? max_1 : wmma_matches;
constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches;
constexpr int max_matches = multi_d_wmma_matches > max_2 ? multi_d_wmma_matches : max_2;
if constexpr (max_matches == wmma_v3_matches) {
using Alg = BwdWmmaV3Algorithm<AlgoType>;
static_assert(Alg::is_valid(), Alg::message());
@@ -831,6 +866,10 @@ consteval void diagnose_bwd_weight_algorithm_signature()
using Alg = BwdWmmaAlgorithm<AlgoType>;
static_assert(Alg::is_valid(), Alg::message());
}
else if constexpr (max_matches == multi_d_wmma_matches) {
using Alg = BwdMultiDWmmaAlgorithm<AlgoType>;
static_assert(Alg::is_valid(), Alg::message());
}
}
else
{

View File

@@ -186,6 +186,11 @@ constexpr auto make_conv_instance()
{
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr (BwdMultiDWmmaAlgorithm<AlgoType>::is_valid())
{
static_assert(false,
"Backward weight convolution with multi-D WMMA algorithm is not yet supported.");
}
else
{
diagnose_bwd_weight_algorithm_signature<AlgoType>();

View File

@@ -178,6 +178,7 @@ if (CK_USE_WMMA)
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp
)
endif()

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DFp16_MultiD_Wmma_Shuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16>"}); // check compute types
}

View File

@@ -547,7 +547,7 @@ struct ConvAlgorithmTemplate : Components...
}
};
// Algorithm types
// Fwd algorithm types
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_>;
@@ -568,6 +568,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
ConvAlgorithmTemplate<ThreadBlock_, FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, Prefetch_, LargeTensorSpecialization_>;
// CK Tile algorithm
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
@@ -583,6 +584,7 @@ struct ConvAlgorithm_Reference
// GPU reference uses simple algorithm, no tile configuration needed
};
// Bwd weight algorithm types
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, TransposeParams_>;
@@ -607,5 +609,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>;
} // namespace ck_tile::builder::test

View File

@@ -443,6 +443,16 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuf
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_<4>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)