Refactor pool fwd (#815)

* Do not hardcode stride

* devicePool2DFwd Inherit devicePool3DFwd

* Move instance declaration out of common

* Add dilation

* use the pool3d rank, because pool2d inherit pooo3d

* calculate Do Ho Wo for the dilation

* Fix header name

* Modify ckProfiler

* Remove pool2d instance

* Remove pool2d in profiler

* Remove pool2d and add dilation

* In to client example, this commit revise following:
1. Add dilation.
2. Use pool3d to implement pool2d

* Refine naming and IsSupportedArgument()

* Add dilation to maxpool bwd example

* clang format

* 1. Remove useless header
2. Fix copyright
3. Refine naming

* Add layout parameter to pool fwd

* clang format

* Fix merge error

* Fix compile error

* Remove layout parameter in derived class

* Refine changlog

* Fix compile error

* Fix compiler error

* Add layout to external api and profiler

[ROCm/composable_kernel commit: f60f0a5e03]
This commit is contained in:
rocking
2023-08-15 02:25:28 +08:00
committed by GitHub
parent 759fb65188
commit 4ee825732f
41 changed files with 824 additions and 1574 deletions

View File

@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in),
@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides),
window_dilations_(window_dilations),
in_left_pads_(in_left_pads),
reduceLength_(1)
{
@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& window_dilations_;
const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_;
};
@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
ck::index_t wi = wo * arg.window_strides_[2] +
x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
ck::index_t wi = wo * arg.window_strides_[2] +
x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads)
{
@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices,
window_spatial_lengths,
window_strides,
window_dilations,
in_left_pads,
in_right_pads};
}

View File

@@ -1,114 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef __fp16__
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
#endif
#ifdef __fp32__
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef __fp32__
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -25,36 +25,38 @@ static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef __fp16__
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef __fp32__
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
@@ -62,6 +64,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>>
{
@@ -70,40 +74,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
else
{
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef __fp32__
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
else
{
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
#endif
}
return op_ptrs;
}
};