Add backward instance to the builder.

The factory still needs to be cleaned up, but since I finally have the
code building and all the tests passing, this is a good time to commit.

There is a lot of generalization in this commit to handle backwards
convolutions.
This commit is contained in:
John Shumway
2025-09-17 13:17:19 +00:00
parent 0e5e514140
commit dd0318e0ab
5 changed files with 259 additions and 30 deletions

View File

@@ -3,6 +3,7 @@
// #include
// "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/conv_algorithm.hpp>
#include <ck_tile/builder/builder_utils.hpp>
@@ -12,7 +13,7 @@
namespace ck_tile::builder {
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
template <GroupConvLayout Layout, int SPATIAL_DIM>
template <GroupConvLayout Layout, int SPATIAL_DIM, ConvDirection DIR>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct ConvTensorLayouts
{
@@ -23,7 +24,7 @@ struct ConvTensorLayouts
};
template <>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2, ConvDirection::FORWARD>
{
// Channels first convolution layout.
using ALayout = ck::tensor_layout::convolution::NHWGC;
@@ -33,7 +34,17 @@ struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
};
template <>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2, ConvDirection::BACKWARD_DATA>
{
// Channels last convolution layout.
using ALayout = ck::tensor_layout::convolution::NGKHW;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGCHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2, ConvDirection::FORWARD>
{
// Channels last convolution layout.
using ALayout = ck::tensor_layout::convolution::NHWGC;
@@ -43,7 +54,7 @@ struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
};
template <>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3, ConvDirection::FORWARD>
{
// Channels last convolution layout.
using ALayout = ck::tensor_layout::convolution::NDHWGC;
@@ -103,16 +114,23 @@ struct ConvPassThroughOps
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};
// The specializations for the convolution and GEMM.
// The algorithm specializations for the convolution and GEMM.
template <typename CONV_ENUM>
requires(
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization> ||
std::is_same_v<CONV_ENUM,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization>)
struct ConvSpec
{
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_spec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
ck::tensor_operation::device::GemmSpecialization gemm_spec =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
CONV_ENUM conv_spec;
ck::tensor_operation::device::GemmSpecialization gemm_spec;
};
// Block info for a convlution.
// Deduction guide for ConvSpec simplifies brace initialization.
// A two-argument initializer deduces ConvSpec<CONV_ENUM> from the type of the
// first argument. GEMM_ENUM is deduced from the second argument but does not
// appear in the resulting type.
template <typename CONV_ENUM, typename GEMM_ENUM>
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
// Block info for a convolution.
struct ConvBlock
{
int block_size = 0;
@@ -148,10 +166,22 @@ struct ConvTuning
int n_xdl_per_wave = 0;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ConvTuning SetConvTuningInfo()
{
using AlgorithmType = decltype(ALGORITHM);
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
// Default values for backward data if tuning info isn't specified.
return ConvTuning{
.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_dxl = 16,
.m_xdl_per_wave = 1,
.n_xdl_per_wave = 4,
};
}
if constexpr(SpecifiesConvTuning<AlgorithmType>)
{
constexpr auto& TP = ALGORITHM.tuning_params;
@@ -192,13 +222,27 @@ struct CBlockTransfer
{
int m_xdl_per_wave_per_shuffle = 0;
int n_xdl_per_wave_per_shuffle = 0;
ck::Array<int, 4> thread_cluster_dims = {0, 0, 0, 0};
ck::Array<int, 4> thread_cluster_dims = {0, 0, 0, 8};
int scaler_per_vector = 8;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetABlockTransfer()
{
using AlgorithmType = decltype(ALGORITHM);
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
// Different default values for backward data.
return BlockTransfer{
.thread_cluster_dims = {4, 16, 1},
.thread_cluster_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
.src_vector_dim = 2,
.src_scaler_per_vector = 8,
.dest_scaler_per_vector_k1 = 8,
.add_extra = 1,
};
}
BlockTransfer block_transfer{
.thread_cluster_dims = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},
@@ -208,7 +252,6 @@ constexpr BlockTransfer SetABlockTransfer()
.dest_scaler_per_vector_k1 = 8,
.add_extra = 0,
};
using AlgorithmType = decltype(ALGORITHM);
if constexpr(SpecifiesBlockATransfer<AlgorithmType>)
{
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_a;
@@ -218,9 +261,23 @@ constexpr BlockTransfer SetABlockTransfer()
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetBBlockTransfer()
{
using AlgorithmType = decltype(ALGORITHM);
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
// Different default values for backward data.
return BlockTransfer{
.thread_cluster_dims = {4, 8, 1},
.thread_cluster_order = {0, 2, 1},
.src_access_order = {0, 2, 1},
.src_vector_dim = 1,
.src_scaler_per_vector = 8,
.dest_scaler_per_vector_k1 = 8,
.add_extra = 1,
};
}
BlockTransfer block_transfer{
.thread_cluster_dims = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},
@@ -230,7 +287,6 @@ constexpr BlockTransfer SetBBlockTransfer()
.dest_scaler_per_vector_k1 = 8,
.add_extra = 0,
};
using AlgorithmType = decltype(ALGORITHM);
if constexpr(SpecifiesBlockBTransfer<AlgorithmType>)
{
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_b;
@@ -240,16 +296,26 @@ constexpr BlockTransfer SetBBlockTransfer()
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
using AlgorithmType = decltype(ALGORITHM);
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
// Different default values for backward data.
return CBlockTransfer{
.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.thread_cluster_dims = {1, 16, 1, 4},
.scaler_per_vector = 4,
};
}
CBlockTransfer block_transfer{
.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.thread_cluster_dims = {1, 32, 1, 8},
.scaler_per_vector = 8,
};
using AlgorithmType = decltype(ALGORITHM);
if constexpr(SpecifiesBlockCTransfer<AlgorithmType>)
{
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
@@ -281,26 +347,32 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
return ck::BlockGemmPipelineVersion::v4;
}
// Factory builds an instance of a grouped convolution kernel.
// Primary template for the convolution factory.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto Version>
requires SupportedVersion<Version>
struct ConvFactory
auto VERSION>
struct ConvFactory;
// Factory builds an instance of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto VERSION>
requires SupportedVersion<VERSION> && ConvDirectionIsForward<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
static constexpr ConvSpec SPECIALIZATION{
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
// The convolution kernel class instance.
@@ -354,4 +426,86 @@ struct ConvFactory
PIPELINE_VERSION>;
};
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
// clang-format on
// Factory builds an instance of a grouped backward-data convolution kernel.
//
// Partial specialization of ConvFactory that is selected when VERSION satisfies
// SupportedVersion and SIGNATURE satisfies ConvDirectionIsBackwardData. The
// resulting kernel type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto VERSION>
requires SupportedVersion<VERSION> && ConvDirectionIsBackwardData<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
// Layouts, data types and element-wise ops are looked up from the signature.
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
// Only conv_spec is forwarded to Instance below. gemm_spec is initialized
// here but is not used by this specialization.
static constexpr ConvSpec SPECIALIZATION{
.conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
// BLOCK is derived from ALGORITHM alone. The tuning and block-transfer helpers
// also receive SIGNATURE so they can pick backward-data defaults.
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// NOTE(review): PIPELINE_SCHEDULER and PIPELINE_VERSION are defined but are not
// passed to Instance below -- confirm whether they are still needed here.
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
// The backward-data convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
SPATIAL_DIM,
// Tensor layouts.
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
// Tensor data types.
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
// Element-wise operations.
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
// The next three arguments are hard-coded rather than taken from ALGORITHM.
true, // DoPadGemmM
true, // DoPadGemmN
1, // NumGemmKPrefetchStage
// Thread block size and per-block tile sizes.
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
// Tuning parameters.
TUNING.ak1,
TUNING.bk1,
TUNING.m_per_xdl,
TUNING.n_per_dxl,
TUNING.m_xdl_per_wave,
TUNING.n_xdl_per_wave,
// A block transfer parameters.
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scaler_per_vector,
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
A_BLOCK_TRANSFER.add_extra,
// B block transfer parameters.
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scaler_per_vector,
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
B_BLOCK_TRANSFER.add_extra,
// C block transfer parameters.
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scaler_per_vector>;
};
} // namespace ck_tile::builder

View File

@@ -59,4 +59,14 @@ concept ValidConvSignature = requires {
requires ConvDataType<Sig.data_type>;
};
// Direction predicates over a signature value. Each concept takes the
// signature as a non-type template argument and compares its `direction`
// member against one ConvDirection enumerator.

// Satisfied when Sig.direction is ConvDirection::FORWARD.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Satisfied when Sig.direction is ConvDirection::BACKWARD_DATA.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Satisfied when Sig.direction is ConvDirection::BACKWARD_WEIGHT.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder

View File

@@ -20,6 +20,7 @@ add_ck_builder_test(test_conv_builder
test_conv_builder.cpp
test_conv_grp_fwd_2d.cpp
test_conv_grp_fwd_3d.cpp
test_conv_grp_bwd_2d.cpp
test_conv_instances.cpp)
add_ck_builder_test(test_builder_utils test_builder_utils.cpp)

View File

@@ -0,0 +1,60 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/conv_builder.hpp>
namespace {
namespace ckb = ck_tile::builder;
using P = ckb::BlockGemmPipelineVersion;
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, and data type.
// Every field has a default member initializer, so a default-constructed
// ConvSignature describes a 2D, channels-last, FP16 backward-data convolution.
struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
ckb::DataType data_type = ckb::DataType::FP16;
};
// Compile-time check that the struct satisfies the builder's signature concept.
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
// Algorithm description used by the test. Only the thread block is specified;
// the tuning and block-transfer members are currently commented out.
// NOTE(review): the backward-data branches visible in the factory helpers return
// fixed defaults before consulting the algorithm, so enabling these members
// presumably has no effect yet -- confirm before re-enabling them.
struct ConvAlgorithm
{
ckb::ThreadBlock thread_block;
// ckb::ConvTuningParams tuning_params;
// struct BlockTransfer
// {
// ckb::BlockATransferLengths thread_cluster_dims_a;
// ckb::BlockBTransferLengths thread_cluster_dims_b;
// // ckb::BlockCTransferLengths thread_cluster_dims_c;
// } block_transfer;
};
// Compile-time checks that the struct satisfies the algorithm concepts it relies on.
// The checks for the commented-out members above are disabled along with them.
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
// static_assert(ckb::SpecifiesConvTuning<ConvAlgorithm>);
// static_assert(ckb::SpecifiesBlockATransfer<ConvAlgorithm>);
// static_assert(ckb::SpecifiesBlockBTransfer<ConvAlgorithm>);
// static_assert(ckb::SpecifiesBlockCTransfer<ConvAlgorithm>);
// Builds a 2D grouped backward-data convolution instance from a signature and a
// thread-block-only algorithm, then checks the instance's type string.
// NOTE(review): an earlier comment here said to keep this test commented out
// until it compiles; the test is active, so that note was stale.
TEST(ConvBuilderGrpBwd2d, TestFirstExample)
{
// test first instance in
// include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
static constexpr const ConvSignature SIGNATURE;
static constexpr const ConvAlgorithm ALGORITHM{
.thread_block = {.block_size = 64, .submatrix = {.m = 16, .n = 64, .k = 32}},
// .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 1},
// .block_transfer = {
// .thread_cluster_dims_a = {.k0 = 4, .m = 16, .k1 = 1},
// .thread_cluster_dims_b = {.k0 = 4, .n = 8, .k1 = 1}}
};
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// TypeString() is static, so no kernel object has to be constructed.
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<64, 16, 64, 32, 8, 8, Default, 16, 16, 1, 4, 8, 8, 1, 1, "
"TransposeTransferInScalarPerVectorAligned: 1, TransposeTransferOutScalarPerVectorAligned: 1>");
}
} // namespace

View File

@@ -1761,7 +1761,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
static std::string TypeString()
{
auto str = std::stringstream();
@@ -1797,6 +1797,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return str.str();
}
// Virtual-interface override that forwards to the static TypeString(), so the
// same string is available both through an object and without one.
std::string GetTypeString() const override {
return TypeString();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);