mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Refactor pool fwd (#815)
* Do not hardcode stride * devicePool2DFwd Inherit devicePool3DFwd * Move instance declaration out of common * Add dilation * use the pool3d rank, because pool2d inherit pooo3d * calculate Do Ho Wo for the dilation * Fix header name * Modify ckProfiler * Remove pool2d instance * Remove pool2d in profiler * Remove pool2d and add dilation * In to client example, this commit revise following: 1. Add dilation. 2. Use pool3d to implement pool2d * Refine naming and IsSupportedArgument() * Add dilation to maxpool bwd example * clang format * 1. Remove useless header 2. Fix copyright 3. Refine naming * Add layout parameter to pool fwd * clang format * Fix merge error * Fix compile error * Remove layout parameter in derived class * Refine changlog * Fix compile error * Fix compiler error * Add layout to external api and profiler
This commit is contained in:
@@ -18,7 +18,45 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
|
||||
|
||||
template <typename InDataType,
|
||||
template <typename TensorLayout>
|
||||
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
|
||||
ck::index_t C_,
|
||||
ck::index_t D,
|
||||
ck::index_t H,
|
||||
ck::index_t W,
|
||||
TensorLayout layout)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
(void)N_;
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
|
||||
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
|
||||
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
|
||||
};
|
||||
|
||||
template <typename TensorLayout>
|
||||
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
|
||||
std::size_t C_,
|
||||
std::size_t D,
|
||||
std::size_t H,
|
||||
std::size_t W,
|
||||
TensorLayout layout)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DevicePoolFwdInstance,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeDataType,
|
||||
typename IndexDataType,
|
||||
@@ -40,6 +78,9 @@ bool pool3d_test(bool do_verification,
|
||||
ck::index_t window_stride_d,
|
||||
ck::index_t window_stride_h,
|
||||
ck::index_t window_stride_w,
|
||||
ck::index_t window_dilation_d,
|
||||
ck::index_t window_dilation_h,
|
||||
ck::index_t window_dilation_w,
|
||||
ck::index_t in_left_pad_d,
|
||||
ck::index_t in_left_pad_h,
|
||||
ck::index_t in_left_pad_w,
|
||||
@@ -47,53 +88,21 @@ bool pool3d_test(bool do_verification,
|
||||
ck::index_t in_right_pad_h,
|
||||
ck::index_t in_right_pad_w)
|
||||
{
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
OutDataType, // OutDataType
|
||||
IndexDataType, // IndexDataType
|
||||
ComputeDataType, // ComputeDataType
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
4>; // InSrcOutDstVectorSize
|
||||
|
||||
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
const ck::index_t Zs = (Z - 1) * window_dilation_d + 1;
|
||||
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
|
||||
const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
|
||||
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1;
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
|
||||
|
||||
const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X};
|
||||
const std::vector<ck::index_t> window_strides{
|
||||
window_stride_d, window_stride_h, window_stride_w};
|
||||
const std::vector<ck::index_t> window_dilations{
|
||||
window_dilation_d, window_dilation_h, window_dilation_w};
|
||||
const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w};
|
||||
const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w};
|
||||
|
||||
// tensor layout
|
||||
auto f_host_tensor_descriptor = [](std::size_t N_,
|
||||
std::size_t C_,
|
||||
std::size_t D,
|
||||
std::size_t H,
|
||||
std::size_t W,
|
||||
auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{C_ * D * H * W, D * H * W, H * W, W, 1_uz});
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NDHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{}));
|
||||
Tensor<OutDataType> out_n_c_do_ho_wo_host(
|
||||
f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{}));
|
||||
@@ -126,10 +135,11 @@ bool pool3d_test(bool do_verification,
|
||||
{N, C, Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{N, C, Do, Ho, Wo},
|
||||
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
|
||||
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
|
||||
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
|
||||
f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, InLayout{}),
|
||||
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
|
||||
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3, 4});
|
||||
@@ -165,6 +175,7 @@ bool pool3d_test(bool do_verification,
|
||||
out_indices_n_c_do_ho_wo_host,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
|
||||
@@ -27,31 +27,49 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
static constexpr bool OutputIndex = false;
|
||||
static constexpr bool PropagateNan = false;
|
||||
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool3dFwd_NDHWC_NDHWC<InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ComputeDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
1, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
1>; // InSrcOutDstVectorSize
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Z = 2;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Di = 30;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_d = 2;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_d = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_d = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Z = 2;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Di = 30;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_d = 2;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t window_dilation_d = 1;
|
||||
ck::index_t window_dilation_h = 1;
|
||||
ck::index_t window_dilation_w = 1;
|
||||
ck::index_t in_left_pad_d = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_d = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
bool pass = pool3d_test<InDataType,
|
||||
bool pass = pool3d_test<DevicePoolFwdInstance,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
@@ -72,6 +90,9 @@ int main()
|
||||
window_stride_d,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
window_dilation_d,
|
||||
window_dilation_h,
|
||||
window_dilation_w,
|
||||
in_left_pad_d,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
|
||||
Reference in New Issue
Block a user