Grouped conv backward data GKCYX support (#2029)

* Grouped conv backward data GKCYX support

* profiler

* Converter

* split instances
This commit is contained in:
Bartłomiej Kocot
2025-04-01 22:24:38 +02:00
committed by GitHub
parent ec742908bd
commit 8c0ab61ece
37 changed files with 1286 additions and 198 deletions

View File

@@ -243,15 +243,21 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I3 = Number<3>{};
using ALayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGK,
ALayout>>;
using BLayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::GKYXC,
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::GKZYXC,
BLayout>>;
using ELayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGC,
ELayout>>;
@@ -265,7 +271,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DoPadGemmM,
DoPadGemmN,
ALayoutAfterTranspose,
BLayout,
BLayoutAfterTranspose,
ELayoutAfterTranspose,
true, /*SplitConvN*/
ABDataType,
@@ -392,7 +398,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
static constexpr index_t ClusterLengthMPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
@@ -418,6 +425,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock;
@@ -426,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapInOutElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
@@ -439,12 +452,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
I1,
I0>;
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
Tuple<GKYXCTransposeDescType>,
Tuple<const BDataType*>,
Tuple<BDataType*>,
Block2TileMapWeiElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<1>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
Block2TileMapInOutElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
@@ -498,6 +529,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
@@ -584,7 +618,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides_transposed,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
b_g_k_c_xs_strides_transposed,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides_transposed,
conv_filter_strides,
@@ -618,7 +652,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DoPadGemmM,
DoPadGemmN,
ALayoutAfterTranspose,
BLayout,
BLayoutAfterTranspose,
DLayout,
true, /*SplitConvN*/
ABDataType,
@@ -627,7 +661,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides_transposed,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
b_g_k_c_xs_strides_transposed,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
@@ -682,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_n_.BatchStrideA_ =
@@ -692,8 +726,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -703,6 +737,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
b_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
b_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
@@ -710,9 +751,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapInOutElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapWeiElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapInOutElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
compute_ptr_offset_of_workspace_n_.BatchStrideA_ =
@@ -724,25 +767,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
}
else
{
@@ -750,6 +781,43 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
else
{
return 0;
}
}
// Total workspace requirement in bytes: the sum of the A, B and E regions.
// Each per-tensor getter already returns 0 when its region is not needed.
std::size_t GetWorkspaceSizeBytes() const
{
    const std::size_t a_region_bytes = GetWorkspaceATensorSizeBytes();
    const std::size_t b_region_bytes = GetWorkspaceBTensorSizeBytes();
    const std::size_t e_region_bytes = GetWorkspaceETensorSizeBytes();
    return a_region_bytes + b_region_bytes + e_region_bytes;
}
void Print() const
{
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
@@ -796,11 +864,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// block-to-e-tile map
std::vector<Block2ETileMap> block_2_etile_map_container_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
GKCYXTransposeDescType b_in_transpose_desc_;
GKYXCTransposeDescType b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
@@ -835,14 +906,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const index_t gdz = arg.num_workgroups_per_Conv_N_;
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
{
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
}
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
@@ -888,7 +969,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
dim3(BlockSize),
0,
p_a_grid,
arg.p_b_grid_,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
@@ -925,11 +1006,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.Print();
}
// Transpose from NGKHW to NHWGK
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
EDataType* p_e_in_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(p_e_in_grid,
@@ -938,47 +1021,72 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
stream_config.stream_id_));
};
const index_t grid_size =
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_) *
arg.num_workgroups_per_Conv_N_;
const index_t b_grid_size =
(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_)
: 0; // Dont run transpose B if not needed
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
auto kernel_transpose =
kernel_batched_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
I1,
I1>;
kernel_elementwise_batched_dual<GridwiseElementwiseInputTranspose,
GridwiseElementwiseWeightTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const BDataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapInOutElementwise,
Block2TileMapWeiElementwise,
element_wise::PassThrough,
I1,
I1,
I1,
I1>;
ave_time += launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel_transpose,
dim3(grid_size),
dim3(a_grid_size + b_grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
a_grid_size,
arg.num_workgroups_per_Conv_N_,
I1, // B is not splited per N
std::array<index_t, I1>{
static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
std::array<index_t, I1>{0},
std::array<index_t, I1>{
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)});
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)},
std::array<index_t, I1>{0});
}
ave_time += RunGemm(arg, stream_config);
// Transpose from NHWGC to NGCHW
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -987,7 +1095,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
EDataType* p_e_out_grid = arg.p_e_grid_;
@@ -997,7 +1106,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
Block2TileMapInOutElementwise,
element_wise::PassThrough,
I1,
I1>;
@@ -1077,7 +1186,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
is_same_v<BLayout, tensor_layout::convolution::GKZYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
is_same_v<BLayout, tensor_layout::convolution::GKCZYX>)
{
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{
@@ -1152,8 +1263,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
{
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
@@ -1320,8 +1431,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>()) {
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>()) {
str << ", TransposeTransferInScalarPerVectorAligned: "
<< TransposeTransferInScalarPerVectorAligned <<", "
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;

View File

@@ -93,6 +93,119 @@ __global__ void
}
}
/// Runs two batched elementwise functors from a single kernel launch.
///
/// Workgroups whose 1D block id lies in [0, a_grid_size) execute
/// GridwiseElementwiseFunctorA on the "A" tensors. All remaining workgroups
/// execute GridwiseElementwiseFunctorB on the "B" tensors and hand the functor
/// a block id rebased by a_grid_size.
///
/// Within each half, the batch index of a workgroup is its (rebased) block id
/// divided by the number of blocks per batch; that index times the matching
/// entry of the batch-stride arrays is added to every input/output pointer
/// before the functor runs.
///
/// NOTE(review): the integer divisions presume that each half's block count is
/// a multiple of its batch count — confirm at the call site.
template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation,
          index_t NumInputsA,
          index_t NumInputsB,
          index_t NumOutputsA,
          index_t NumOutputsB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise_batched_dual(
        const InAGridDescTuple in_grid_desc_tuple_a,
        const InBGridDescTuple in_grid_desc_tuple_b,
        const OutAGridDescTuple out_grid_desc_tuple_a,
        const OutBGridDescTuple out_grid_desc_tuple_b,
        const InADataTypePointerTuple p_in_global_tuple_a,
        const InBDataTypePointerTuple p_in_global_tuple_b,
        const OutADataTypePointerTuple p_out_global_tuple_a,
        const OutBDataTypePointerTuple p_out_global_tuple_b,
        const Block2TileMapA block_2_tile_map_a,
        const Block2TileMapB block_2_tile_map_b,
        const ElementwiseOperation elementwise_op,
        const index_t a_grid_size,
        const index_t batch_count_a,
        const index_t batch_count_b,
        const std::array<index_t, NumInputsA> input_batch_strides_a,
        const std::array<index_t, NumInputsB> input_batch_strides_b,
        const std::array<index_t, NumOutputsA> output_batch_strides_a,
        const std::array<index_t, NumOutputsB> output_batch_strides_b)
{
    // Tuple arities must agree with the declared input/output counts.
    static_assert(InAGridDescTuple::Size() == NumInputsA &&
                  InADataTypePointerTuple::Size() == NumInputsA);
    static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
                  OutADataTypePointerTuple::Size() == NumOutputsA);
    static_assert(InBGridDescTuple::Size() == NumInputsB &&
                  InBDataTypePointerTuple::Size() == NumInputsB);
    static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
                  OutBDataTypePointerTuple::Size() == NumOutputsB);

    const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());

    if(block_id >= a_grid_size)
    {
        // Trailing part of the grid: workgroups assigned to the B tensors.
        // Their block id is rebased so that the first B workgroup sees 0.
        const index_t b_block_id = block_id - a_grid_size;

        const index_t b_blocks_per_batch =
            __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
        const index_t b_batch_idx =
            __builtin_amdgcn_readfirstlane(b_block_id / b_blocks_per_batch);

        InBDataTypePointerTuple p_in_b_batch;
        OutBDataTypePointerTuple p_out_b_batch;

        // Advance every B input pointer to the start of this workgroup's batch.
        static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto tensor_idx) {
            const long_index_t in_offset =
                type_convert<long_index_t>(input_batch_strides_b[tensor_idx]) * b_batch_idx;
            p_in_b_batch(tensor_idx) = p_in_global_tuple_b.At(tensor_idx) + in_offset;
        });

        // Advance every B output pointer to the start of this workgroup's batch.
        static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto tensor_idx) {
            const long_index_t out_offset =
                type_convert<long_index_t>(output_batch_strides_b[tensor_idx]) * b_batch_idx;
            p_out_b_batch(tensor_idx) = p_out_global_tuple_b.At(tensor_idx) + out_offset;
        });

        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_b_batch,
                                         p_out_b_batch,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         b_block_id);
    }
    else
    {
        // Leading part of the grid: workgroups assigned to the A tensors.
        const index_t a_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
        const index_t a_batch_idx =
            __builtin_amdgcn_readfirstlane(block_id / a_blocks_per_batch);

        InADataTypePointerTuple p_in_a_batch;
        OutADataTypePointerTuple p_out_a_batch;

        // Advance every A input pointer to the start of this workgroup's batch.
        static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto tensor_idx) {
            const long_index_t in_offset =
                type_convert<long_index_t>(input_batch_strides_a[tensor_idx]) * a_batch_idx;
            p_in_a_batch(tensor_idx) = p_in_global_tuple_a.At(tensor_idx) + in_offset;
        });

        // Advance every A output pointer to the start of this workgroup's batch.
        static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto tensor_idx) {
            const long_index_t out_offset =
                type_convert<long_index_t>(output_batch_strides_a[tensor_idx]) * a_batch_idx;
            p_out_a_batch(tensor_idx) = p_out_global_tuple_a.At(tensor_idx) + out_offset;
        });

        // The A functor receives the un-rebased block id.
        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_a_batch,
                                         p_out_a_batch,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         block_id);
    }
}
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,