mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Grouped 3d conv backward data support (#799)
* Grouped 3d conv backward data support * Fix comments
This commit is contained in:
@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CDEElementwiseOp>
|
||||
{
|
||||
// FIXME
|
||||
static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now");
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
"wrong! only implemented for 2D and 3D now");
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
|
||||
@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
|
||||
});
|
||||
|
||||
static constexpr auto NonSpatialDimsNum = Number<3>{};
|
||||
|
||||
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
|
||||
static constexpr auto HIdx =
|
||||
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
|
||||
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
|
||||
: Number<NonSpatialDimsNum + 2>{};
|
||||
|
||||
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
|
||||
static constexpr auto YIdx =
|
||||
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
|
||||
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
|
||||
: Number<NonSpatialDimsNum + 2>{};
|
||||
|
||||
// problem definition
|
||||
const index_t Y = b_g_k_c_xs_lengths[3];
|
||||
const index_t X = b_g_k_c_xs_lengths[4];
|
||||
const index_t Z = b_g_k_c_xs_lengths[ZIdx];
|
||||
const index_t Y = b_g_k_c_xs_lengths[YIdx];
|
||||
const index_t X = b_g_k_c_xs_lengths[XIdx];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides_[0];
|
||||
const index_t ConvStrideW = conv_filter_strides_[1];
|
||||
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
|
||||
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations_[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations_[1];
|
||||
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
// check slice is valid
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
if(YDotSlice * XDotSlice <= 0)
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// check slice is valid
|
||||
const auto ZDotSlice =
|
||||
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ytilde, i_xtilde});
|
||||
if(YDotSlice * XDotSlice * ZDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ytilde, i_xtilde});
|
||||
std::array<index_t, NDimSpatial> tildes;
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
tildes = {i_ytilde, i_xtilde};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
tildes = {i_ztilde, i_ytilde, i_xtilde};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ytilde, i_xtilde});
|
||||
});
|
||||
tildes);
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ytilde, i_xtilde});
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
|
||||
// desc for problem definition
|
||||
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
|
||||
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
|
||||
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
// desc for blockwise copy
|
||||
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
|
||||
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
// block-to-e-tile-map
|
||||
auto block_2_etile_map =
|
||||
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
|
||||
block_2_etile_map_container_.push_back(block_2_etile_map);
|
||||
// desc for problem definition
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
ds_grid_desc_m_n,
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map))
|
||||
{
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n));
|
||||
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
|
||||
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n));
|
||||
// desc for blockwise copy
|
||||
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
|
||||
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
|
||||
|
||||
// block-to-e-tile-map
|
||||
auto block_2_etile_map =
|
||||
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
|
||||
|
||||
block_2_etile_map_container_.push_back(block_2_etile_map);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
ds_grid_desc_m_n,
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map))
|
||||
{
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n));
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// vector load for A matrix from global memory to LDS
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// vector load for B matrix from global memory to LDS
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
|
||||
@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// vector store for E
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
|
||||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
|
||||
Reference in New Issue
Block a user