mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add backward instance to the builder.
The factory still needs to be cleaned up, but since I finally have the code building and all the test passing, this is a good time to commit. There is a lot of generalization in this commit to handle backwards convolutions.
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
// #include
|
||||
// "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/conv_algorithm.hpp>
|
||||
#include <ck_tile/builder/builder_utils.hpp>
|
||||
@@ -12,7 +13,7 @@
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
|
||||
template <GroupConvLayout Layout, int SPATIAL_DIM>
|
||||
template <GroupConvLayout Layout, int SPATIAL_DIM, ConvDirection DIR>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
@@ -23,7 +24,7 @@ struct ConvTensorLayouts
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
// Channels first convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -33,7 +34,17 @@ struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2, ConvDirection::BACKWARD_DATA>
|
||||
{
|
||||
// Channels last convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NGKHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGCHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
// Channels last convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -43,7 +54,7 @@ struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
// Channels last convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -103,16 +114,23 @@ struct ConvPassThroughOps
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
// The specializations for the convolution and GEMM.
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization> ||
|
||||
std::is_same_v<CONV_ENUM,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization>)
|
||||
struct ConvSpec
|
||||
{
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_spec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
CONV_ENUM conv_spec;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec;
|
||||
};
|
||||
|
||||
// Block info for a convlution.
|
||||
// Deduction guide for ConvSpec simplifies brace initialization.
|
||||
template <typename CONV_ENUM, typename GEMM_ENUM>
|
||||
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
|
||||
|
||||
// Block info for a convolution.
|
||||
struct ConvBlock
|
||||
{
|
||||
int block_size = 0;
|
||||
@@ -148,10 +166,22 @@ struct ConvTuning
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvTuning SetConvTuningInfo()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
// Default values for backward data if tuning info isn't specified.
|
||||
return ConvTuning{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_dxl = 16,
|
||||
.m_xdl_per_wave = 1,
|
||||
.n_xdl_per_wave = 4,
|
||||
};
|
||||
}
|
||||
if constexpr(SpecifiesConvTuning<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TP = ALGORITHM.tuning_params;
|
||||
@@ -192,13 +222,27 @@ struct CBlockTransfer
|
||||
{
|
||||
int m_xdl_per_wave_per_shuffle = 0;
|
||||
int n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<int, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
ck::Array<int, 4> thread_cluster_dims = {0, 0, 0, 8};
|
||||
int scaler_per_vector = 8;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetABlockTransfer()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
// Different default values for backward data.
|
||||
return BlockTransfer{
|
||||
.thread_cluster_dims = {4, 16, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 1,
|
||||
};
|
||||
}
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_dims = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
@@ -208,7 +252,6 @@ constexpr BlockTransfer SetABlockTransfer()
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(SpecifiesBlockATransfer<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_a;
|
||||
@@ -218,9 +261,23 @@ constexpr BlockTransfer SetABlockTransfer()
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetBBlockTransfer()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
// Different default values for backward data.
|
||||
return BlockTransfer{
|
||||
.thread_cluster_dims = {4, 8, 1},
|
||||
.thread_cluster_order = {0, 2, 1},
|
||||
.src_access_order = {0, 2, 1},
|
||||
.src_vector_dim = 1,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 1,
|
||||
};
|
||||
}
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_dims = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
@@ -230,7 +287,6 @@ constexpr BlockTransfer SetBBlockTransfer()
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(SpecifiesBlockBTransfer<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_b;
|
||||
@@ -240,16 +296,26 @@ constexpr BlockTransfer SetBBlockTransfer()
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
// Different default values for backward data.
|
||||
return CBlockTransfer{
|
||||
.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.thread_cluster_dims = {1, 16, 1, 4},
|
||||
.scaler_per_vector = 4,
|
||||
};
|
||||
}
|
||||
CBlockTransfer block_transfer{
|
||||
.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.thread_cluster_dims = {1, 32, 1, 8},
|
||||
.scaler_per_vector = 8,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(SpecifiesBlockCTransfer<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
@@ -281,26 +347,32 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
|
||||
// Factory builds an instance of a grouped convolution kernel.
|
||||
// Primary template for the convolution factory.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
struct ConvFactory
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
|
||||
// Factory builds an instance of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
requires SupportedVersion<VERSION> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
// The convlution kernel class instance.
|
||||
@@ -354,4 +426,86 @@ struct ConvFactory
|
||||
PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
|
||||
// clang-format on
|
||||
|
||||
// Factory builds an instance of a grouped backward-data convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
requires SupportedVersion<VERSION> && ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
true, // DoPadGemmM
|
||||
true, // DoPadGemmN
|
||||
1, // NumGemmKPrefetchStage
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
TUNING.ak1,
|
||||
TUNING.bk1,
|
||||
TUNING.m_per_xdl,
|
||||
TUNING.n_per_dxl,
|
||||
TUNING.m_xdl_per_wave,
|
||||
TUNING.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
A_BLOCK_TRANSFER.add_extra,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
B_BLOCK_TRANSFER.add_extra,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scaler_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -59,4 +59,14 @@ concept ValidConvSignature = requires {
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
};
|
||||
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -20,6 +20,7 @@ add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_conv_grp_fwd_2d.cpp
|
||||
test_conv_grp_fwd_3d.cpp
|
||||
test_conv_grp_bwd_2d.cpp
|
||||
test_conv_instances.cpp)
|
||||
|
||||
add_ck_builder_test(test_builder_utils test_builder_utils.cpp)
|
||||
60
experimental/builder/test/test_conv_grp_bwd_2d.cpp
Normal file
60
experimental/builder/test/test_conv_grp_bwd_2d.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
using P = ckb::BlockGemmPipelineVersion;
|
||||
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ckb::ThreadBlock thread_block;
|
||||
// ckb::ConvTuningParams tuning_params;
|
||||
// struct BlockTransfer
|
||||
// {
|
||||
// ckb::BlockATransferLengths thread_cluster_dims_a;
|
||||
// ckb::BlockBTransferLengths thread_cluster_dims_b;
|
||||
// // ckb::BlockCTransferLengths thread_cluster_dims_c;
|
||||
// } block_transfer;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
// static_assert(ckb::SpecifiesConvTuning<ConvAlgorithm>);
|
||||
// static_assert(ckb::SpecifiesBlockATransfer<ConvAlgorithm>);
|
||||
// static_assert(ckb::SpecifiesBlockBTransfer<ConvAlgorithm>);
|
||||
// static_assert(ckb::SpecifiesBlockCTransfer<ConvAlgorithm>);
|
||||
|
||||
// Comment out this test until it compiles.
|
||||
|
||||
TEST(ConvBuilderGrpBwd2d, TestFirstExample)
|
||||
{
|
||||
// test first instance in
|
||||
// include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const ConvAlgorithm ALGORITHM{
|
||||
.thread_block = {.block_size = 64, .submatrix = {.m = 16, .n = 64, .k = 32}},
|
||||
// .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 1},
|
||||
// .block_transfer = {
|
||||
// .thread_cluster_dims_a = {.k0 = 4, .m = 16, .k1 = 1},
|
||||
// .thread_cluster_dims_b = {.k0 = 4, .n = 8, .k1 = 1}}
|
||||
};
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<64, 16, 64, 32, 8, 8, Default, 16, 16, 1, 4, 8, 8, 1, 1, "
|
||||
"TransposeTransferInScalarPerVectorAligned: 1, TransposeTransferOutScalarPerVectorAligned: 1>");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -1761,7 +1761,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
static std::string TypeString()
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
@@ -1797,6 +1797,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return str.str();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override {
|
||||
return TypeString();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
Reference in New Issue
Block a user