mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Add bwd weight XDL CShuffle V3 factory.
This commit is contained in:
@@ -29,18 +29,20 @@ concept OutputVectorTransferLimits = requires {
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
concept AccessOrderLimits3D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
(Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3));
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {1, 2, 3} for the last three elements.
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2, 3}.
|
||||
template <auto Value>
|
||||
concept BwdAccessOrderLimits = requires {
|
||||
requires((Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
|
||||
(Value[1] >= 1 && Value[1] < 4) && (Value[2] >= 1 && Value[2] < 4) &&
|
||||
(Value[3] >= 1 && Value[3] < 4)) && (Value[0] == 0);
|
||||
concept AccessOrderLimits4D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) &&
|
||||
(Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
|
||||
(Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) &&
|
||||
(Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) &&
|
||||
(Value.Size() == 4));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -279,6 +279,47 @@ struct BwdXdlAlgorithm {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BwdXdlV3Algorithm {
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransferBwd)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm)
|
||||
CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockGemm)
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransferBwd;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm;
|
||||
static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization;
|
||||
static constexpr bool c9 = c_SpecifiesBlockGemm;
|
||||
|
||||
static consteval bool is_valid() {
|
||||
return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9;
|
||||
}
|
||||
|
||||
static consteval auto message() -> std::string {
|
||||
return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n"
|
||||
"Concepts for BwdXdlV3 Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockGemm);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_fwd_xdl_v3() {
|
||||
using Alg = FwdXdlV3Algorithm<T>;
|
||||
@@ -309,6 +350,12 @@ consteval int count_matches_bwd_xdl() {
|
||||
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_bwd_xdl_v3() {
|
||||
using Alg = BwdXdlV3Algorithm<T>;
|
||||
return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval int count_matches_large_tensor() {
|
||||
using Alg = LargeTensorAlgorithm<T>;
|
||||
@@ -368,12 +415,20 @@ consteval void diagnose_fwd_algorithm_signature()
|
||||
template <typename AlgoType>
|
||||
consteval void diagnose_bwd_weight_algorithm_signature()
|
||||
{
|
||||
constexpr int xdl_matches = count_matches_fwd_xdl<AlgoType>();
|
||||
constexpr int max_matches = xdl_matches;
|
||||
constexpr int xdl_matches = count_matches_bwd_xdl<AlgoType>();
|
||||
constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3<AlgoType>();
|
||||
|
||||
constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches;
|
||||
|
||||
if constexpr (max_matches == xdl_matches) {
|
||||
using Alg = BwdXdlAlgorithm<AlgoType>;
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
} else {
|
||||
}
|
||||
else if constexpr (max_matches == xdl_v3_matches) {
|
||||
using Alg = BwdXdlV3Algorithm<AlgoType>;
|
||||
static_assert(Alg::is_valid(), Alg::message());
|
||||
}
|
||||
else {
|
||||
// This should never happen
|
||||
static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics.");
|
||||
}
|
||||
|
||||
@@ -47,10 +47,10 @@ struct ConvBwdWeightXdlFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(BwdAccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(BwdAccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(BwdAccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(BwdAccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>, "Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>, "Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>, "Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>, "Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -49,6 +49,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Disable pragma message warnings for factory selection diagnostics
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-W#pragma-messages"
|
||||
#endif
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
@@ -64,6 +70,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -96,28 +103,34 @@ constexpr auto make_conv_instance()
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(TileAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvTileFactory...")
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(FwdXdlV3Algorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdXdlV3Factory...")
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdXdlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdXdlFactory...")
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdWmmaAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdWmmaFactory...")
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdDlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdDlFactory...")
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(LargeTensorAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...")
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
@@ -135,8 +148,14 @@ constexpr auto make_conv_instance()
|
||||
{
|
||||
if constexpr (BwdXdlAlgorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...")
|
||||
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr (BwdXdlV3Algorithm<AlgoType>::is_valid())
|
||||
{
|
||||
#pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...")
|
||||
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
diagnose_bwd_weight_algorithm_signature<AlgoType>();
|
||||
@@ -152,3 +171,8 @@ constexpr auto make_conv_instance()
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
// Re-enable pragma message warnings
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
@@ -54,10 +54,10 @@ struct ConvFwdLargeTensorFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
|
||||
@@ -56,10 +56,10 @@ struct ConvFwdXdlV3Factory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
|
||||
@@ -52,10 +52,10 @@ struct ConvFwdWmmaFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
|
||||
@@ -51,10 +51,10 @@ struct ConvFwdXdlFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
|
||||
@@ -62,16 +62,40 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BwdBlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
constexpr auto array_length = block_order.order.size();
|
||||
static_assert(block_order.order.size() == src_order.order.size(),
|
||||
"Mismatched size between block order and src order");
|
||||
|
||||
if constexpr (array_length == 3)
|
||||
{
|
||||
return BwdBlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else if constexpr (array_length == 4)
|
||||
{
|
||||
return BwdBlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Internal error: Unsupported array length");
|
||||
}
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
|
||||
@@ -142,6 +142,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
)
|
||||
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"FILTER_1X1_STRIDE1_PAD0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
}
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
.accumulation_data_type = INT32,
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
|
||||
@@ -74,7 +74,7 @@ struct BlockGemm
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
template <bool IsBwd = false>
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
@@ -83,16 +83,16 @@ struct BlockTransfer
|
||||
size_t k_batch_size;
|
||||
};
|
||||
|
||||
// Specialization for forward (IsBwd = false)
|
||||
// Specialization for ThreadSliceLength == 3
|
||||
template <>
|
||||
struct BlockTransfer<false>
|
||||
struct BlockTransfer<3>
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<>>);
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<true>>);
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<4>>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
@@ -130,13 +130,13 @@ struct AccessOrder
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<>>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
|
||||
|
||||
template <bool IsBwd = false>
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct InputTransfer
|
||||
{
|
||||
BlockTransfer<IsBwd> block_transfer;
|
||||
BlockTransfer<ThreadSliceLength> block_transfer;
|
||||
LdsTransfer lds_transfer;
|
||||
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> block_transfer_access_order;
|
||||
std::conditional_t<IsBwd, AccessOrder<4>, AccessOrder<3>> src_access_order;
|
||||
AccessOrder<ThreadSliceLength> block_transfer_access_order;
|
||||
AccessOrder<ThreadSliceLength> src_access_order;
|
||||
};
|
||||
|
||||
struct OutputTransfer
|
||||
@@ -145,11 +145,11 @@ struct OutputTransfer
|
||||
Epilogue epilogue;
|
||||
};
|
||||
|
||||
template <bool IsBwd = false>
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer
|
||||
{
|
||||
InputTransfer<IsBwd> a;
|
||||
InputTransfer<IsBwd> b;
|
||||
InputTransfer<ThreadSliceLength> a;
|
||||
InputTransfer<ThreadSliceLength> b;
|
||||
OutputTransfer c;
|
||||
};
|
||||
|
||||
@@ -213,10 +213,10 @@ struct WmmaGemm_
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
template <bool IsBwd = false>
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer_
|
||||
{
|
||||
Transfer<IsBwd> transfer;
|
||||
Transfer<ThreadSliceLength> transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecializationFwd_
|
||||
@@ -397,7 +397,7 @@ struct ConvAlgorithmTemplate : Components...
|
||||
constexpr auto with_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Transfer_<>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<Transfer_<true>, ConvAlgorithmTemplate>);
|
||||
std::is_base_of_v<Transfer_<4>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
@@ -553,6 +553,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
|
||||
TileOptimizations_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<true>, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<4>, ConvSpecializationBwdWeight_, TransposeParams_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, BlockGemm_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -72,8 +72,7 @@ constexpr Transfer<> Transfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr bool BWD = true;
|
||||
constexpr Transfer<BWD> BwdTransfer_4x64x1{
|
||||
constexpr Transfer<4> BwdTransfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
@@ -106,6 +105,39 @@ constexpr Transfer<BWD> BwdTransfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.a =
|
||||
{
|
||||
@@ -210,6 +242,10 @@ constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
@@ -251,6 +287,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
|
||||
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 32, .n = 32, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
|
||||
@@ -120,17 +120,21 @@ inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(BlockTransfer<IsBwd> t)
|
||||
template <size_t ThreadSliceDim>
|
||||
inline std::string to_string(BlockTransfer<ThreadSliceDim> t)
|
||||
{
|
||||
if constexpr (IsBwd)
|
||||
if constexpr (ThreadSliceDim == 4)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
else if constexpr (ThreadSliceDim == 3)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -156,8 +160,8 @@ inline std::string to_string(AccessOrder<N> t)
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(InputTransfer<IsBwd> t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(InputTransfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
@@ -176,8 +180,8 @@ inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(Transfer<IsBwd> t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(Transfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -267,8 +271,8 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <bool IsBwd>
|
||||
inline std::string to_string(Transfer_<IsBwd> t)
|
||||
template <size_t ThreadSliceDim = 3>
|
||||
inline std::string to_string(Transfer_<ThreadSliceDim> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
@@ -378,9 +382,18 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuff
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
constexpr bool BWD = true;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<BWD>>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user