Grouped 3d conv backward data support (#799)

* Grouped 3d conv backward data support

* Fix comments
This commit is contained in:
Bartłomiej Kocot
2023-07-18 18:01:33 +02:00
committed by GitHub
parent f82bd59389
commit 49180fd60b
23 changed files with 1654 additions and 625 deletions

View File

@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEElementwiseOp>
{
// FIXME
static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now");
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
// problem definition
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t Z = b_g_k_c_xs_lengths[ZIdx];
const index_t Y = b_g_k_c_xs_lengths[YIdx];
const index_t X = b_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
continue;
}
// check slice is valid
const auto ZDotSlice =
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
if(YDotSlice * XDotSlice * ZDotSlice <= 0)
{
continue;
}
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
std::array<index_t, NDimSpatial> tildes;
if constexpr(NDimSpatial == 2)
{
tildes = {i_ytilde, i_xtilde};
}
else if constexpr(NDimSpatial == 3)
{
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
});
tildes);
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
// desc for problem definition
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
DsGridDesc_M_N ds_grid_desc_m_n;
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
block_2_etile_map_container_.push_back(block_2_etile_map);
// desc for problem definition
const auto a_grid_desc_m_k =
transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k =
transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
}
}
}
}
@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{
@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))