Support large batch tensors in grouped conv bwd data (#1711)

* Support large batch tensors in grouped conv bwd data

* Fix multiD

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2024-12-06 10:55:23 +01:00
committed by GitHub
parent 58e7f37fc8
commit 261f1759de
6 changed files with 1081 additions and 857 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto transform_conv_to_gemm =
TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
K1,
K1,
MPerBlock,
NPerBlock,
KPerBlock,
true /* DoPadGemmM */,
true /* DoPadGemmN */>{};
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
K1,
K1,
MPerBlock,
NPerBlock,
KPerBlock,
true /* DoPadGemmM */,
true /* DoPadGemmN */,
ALayout,
BLayout,
ELayout>;
static auto GetDummyABDsEGridDescriptor()
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
// desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths,
/*ds_g_n_c_wis_lengths*/,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
static_assert(is_same_v<DLayout, ELayout>);
ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};

View File

@@ -54,15 +54,16 @@ template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AElementwiseOp,
typename BElementwiseOp,
typename CDEElementwiseOp,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -73,10 +74,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -84,24 +84,29 @@ __global__ void
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
@@ -112,10 +117,10 @@ __global__ void
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
@@ -130,7 +135,6 @@ __global__ void
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -139,6 +143,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = compute_ptr_offset_of_n;
ignore = block_2_ctile_map;
#endif
}
@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto transform_conv_to_gemm =
TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN>{};
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
BLayout,
ELayout,
true, /*SplitConvN*/
ABDataType,
EDataType>;
static auto GetDummyABDsEGridDescriptor()
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using ConvToGemmBwdDataTransformD =
TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
BLayout,
DLayout,
true, /*SplitConvN*/
ABDataType,
DDataType>;
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
});
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
conv_N_per_block_ = conv_to_gemm_transform_.N_;
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using ConvToGemmBwdDataTransformD =
TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
BLayout,
DLayout,
true, /*SplitConvN*/
ABDataType,
DDataType>;
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// desc for problem definition
const auto a_grid_desc_m_k =
@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
}
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
}
void Print() const
@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition
index_t num_group_;
index_t conv_N_per_block_;
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOp a_element_op_;
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.Print();
}
float ave_time = 0;
const index_t gdy = arg.num_group_;
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdz = num_workgroups_per_Conv_N;
float ave_time = 0;
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! device_op has invalid setting");
}
const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize(
arg.e_grid_desc_m_n_container_[i]) *
arg.num_group_;
const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
arg.e_grid_desc_m_n_container_[i]);
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_k_wos_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_container_[i],
arg.b_grid_desc_bk0_n_bk1_container_[i],
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.block_2_etile_map_container_[i],
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))