Pool2d max/avg kernel in the BWD version (#1494)

* Add pool2d instance BWD AVG

* Add pool2d instance BWD MAX

* Fix: avg review

* Fix review: part2

* Fix - enable test when type is compiled

* Fix review part3
This commit is contained in:
Mateusz Ozga
2024-09-12 11:47:52 +02:00
committed by GitHub
parent e8d2887cb2
commit 448c0f56d8
25 changed files with 2168 additions and 16 deletions

View File

@@ -0,0 +1,523 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// In and Din = [N, C, Hi, Wi]
// Out and Dout = [N, C, Ho, Wo]
// Out = AvgPool2dFwd(In)
// Din = AvgPool2dBwd(Dout)
// Pooling dimension = H, W
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DeviceAvgPool2dBwd_NHWC_NHWC : public DeviceAvgPoolBwd<2,
DOutDataType,
DInDataType,
tensor_layout::convolution::NHWC,
tensor_layout::convolution::NHWC>
{
static constexpr ck::index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto
Make2DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
const std::vector<ck::index_t>& din_n_c_wos_length,
const std::vector<ck::index_t>& dout_n_c_wos_strides,
const std::vector<ck::index_t>& din_n_c_wos_strides,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads,
const std::vector<ck::index_t>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t N = dout_n_c_wos_lengths[0];
const index_t C = dout_n_c_wos_lengths[1];
const index_t Ho = dout_n_c_wos_lengths[2];
const index_t Wo = dout_n_c_wos_lengths[3];
const index_t Hi = din_n_c_wos_length[2];
const index_t Wi = din_n_c_wos_length[3];
const index_t Y = window_lengths[0];
const index_t X = window_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t ConvDilationH = window_dilations[0];
const index_t ConvDilationW = window_dilations[1];
const index_t Ni_stride = dout_n_c_wos_strides[0];
const index_t Ci_stride = dout_n_c_wos_strides[1];
const index_t Ho_stride = dout_n_c_wos_strides[2];
const index_t Wo_stride = dout_n_c_wos_strides[3];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on Tildes that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// ReduceK is different for each Reduce
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Problem size of reduction kernel
const index_t MRaw = N * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride));
// Out[ReduceM, ReduceK]
const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)),
make_merge_transform(make_tuple(YDotSlice, XDotSlice))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
out_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// In[ReduceM]
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(din_n_c_wos_strides[0],
din_n_c_wos_strides[2],
din_n_c_wos_strides[3],
din_n_c_wos_strides[1]));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto in_grid_desc_reducem =
transform_tensor_descriptor(in_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
}
using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0}));
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
// FIXME
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
reduce::Add,
PassThrough,
Div,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
OutSrcInDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const DOutDataType* p_dout,
DInDataType* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout},
p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1]}
{
std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i)
{
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
Tildes[i] = window_strides[i] / GcdStrideDilation;
num_reduce_ *= Tildes[i];
}
for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde)
{
const auto YDotSlice =
math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]);
const auto XDotSlice =
math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto dout_din_grid_desc =
Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
}
}
}
const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
Div div_element_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
{
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
false,
false,
false, // don't have index input
DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
PassThrough,
Div>;
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.dout_grid_desc_m_k_container_[i],
arg.din_grid_desc_m_container_[i],
PassThrough{},
arg.div_element_op_,
float(1),
arg.p_dout_grid_,
nullptr,
float(0),
arg.p_din_grid_,
nullptr);
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1))
{
return false;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
{
throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or "
"[dout|din]_n_c_wos_lengths is not equal to Rank");
}
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
{
throw std::runtime_error(
"dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
"input_right_pads] is not equal to Rank");
}
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceAvgPool2dBwd<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck