mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
This commit is contained in:
@@ -502,27 +502,6 @@ struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase<T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase<T> {
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesMultipleDSupport)
|
||||
|
||||
static constexpr bool c9 = c_SpecifiesBlockGemm;
|
||||
static constexpr bool c10 = c_SpecifiesMultipleDSupport;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c9 && c10 && BwdWmmaAlgorithmBase<T>::is_valid();
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdMultiDWmma Algorithm:\n") +
|
||||
BwdWmmaAlgorithmBase<T>::message() +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesMultipleDSupport);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdWmmaV3AlgorithmBase {
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
@@ -534,7 +513,6 @@ struct BwdWmmaV3AlgorithmBase {
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
@@ -545,10 +523,9 @@ struct BwdWmmaV3AlgorithmBase {
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm;
|
||||
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesBlockGemm;
|
||||
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10;
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
@@ -561,26 +538,46 @@ struct BwdWmmaV3AlgorithmBase {
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer);
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase<T> {
|
||||
CHECK_CONCEPT(T, SpecifiesMultipleDSupport)
|
||||
|
||||
static constexpr bool c10 = c_SpecifiesMultipleDSupport;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c10 && BwdWmmaAlgorithmBase<T>::is_valid();
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdMultiDWmma Algorithm:\n") +
|
||||
BwdWmmaAlgorithmBase<T>::message() +
|
||||
DIAGNOSTIC_LINE(SpecifiesMultipleDSupport);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase<T>
|
||||
{
|
||||
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesGenericInstance)
|
||||
|
||||
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
|
||||
static constexpr bool c11 = c_SpecifiesGenericInstance;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c11 && BwdWmmaV3AlgorithmBase<T>::is_valid();
|
||||
return c10 && c11 && BwdWmmaV3AlgorithmBase<T>::is_valid();
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdWmmaV3 Algorithm:\n") +
|
||||
BwdWmmaV3AlgorithmBase<T>::message() +
|
||||
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGenericInstance);
|
||||
}
|
||||
};
|
||||
@@ -588,20 +585,23 @@ struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase<T>
|
||||
template <typename T>
|
||||
struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase<T>
|
||||
{
|
||||
CHECK_CONCEPT(T, SpecifiesTransposeTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesTwoStageSupport)
|
||||
CHECK_CONCEPT(T, SpecifiesGemmBatchOptions)
|
||||
|
||||
static constexpr bool c10 = c_SpecifiesTransposeTransfer;
|
||||
static constexpr bool c11 = c_SpecifiesTwoStageSupport;
|
||||
static constexpr bool c12 = c_SpecifiesGemmBatchOptions;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c11 && c12 && BwdWmmaV3AlgorithmBase<T>::is_valid();
|
||||
return c10 && c11 && c12 && BwdWmmaV3AlgorithmBase<T>::is_valid();
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdTwoStageWmmaV3 Algorithm:\n") +
|
||||
BwdWmmaV3AlgorithmBase<T>::message() +
|
||||
DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) +
|
||||
DIAGNOSTIC_LINE(SpecifiesTwoStageSupport);
|
||||
}
|
||||
@@ -698,7 +698,7 @@ consteval int count_matches_bwd_wmma() {
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_bwd_multi_d_wmma() {
|
||||
using Alg = BwdMultiDWmmaAlgorithm<T>;
|
||||
using Alg = BwdMultiDWmmaV3Algorithm<T>;
|
||||
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12;
|
||||
}
|
||||
|
||||
@@ -867,7 +867,7 @@ consteval void diagnose_bwd_weight_algorithm_signature()
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
}
|
||||
else if constexpr (max_matches == multi_d_wmma_matches) {
|
||||
using Alg = BwdMultiDWmmaAlgorithm<AlgoType>;
|
||||
using Alg = BwdMultiDWmmaV3Algorithm<AlgoType>;
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
|
||||
struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>, "Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>, "Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>, "Invalid A source access order");
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>, "Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -76,6 +76,7 @@
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -186,10 +187,9 @@ constexpr auto make_conv_instance()
|
||||
{
|
||||
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr (BwdMultiDWmmaAlgorithm<AlgoType>::is_valid())
|
||||
else if constexpr (BwdMultiDWmmaV3Algorithm<AlgoType>::is_valid())
|
||||
{
|
||||
static_assert(false,
|
||||
"Backward weight convolution with multi-D WMMA algorithm is not yet supported.");
|
||||
return typename ConvBwdWeightMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -20,9 +20,9 @@ constexpr auto SIGNATURE =
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
|
||||
|
||||
|
||||
@@ -610,7 +610,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>;
|
||||
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -449,7 +449,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_W
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3(
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
@@ -861,7 +861,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
@@ -875,7 +875,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
@@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
@@ -914,7 +914,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
|
||||
Reference in New Issue
Block a user