Pool3d fwd (#697)

* Expand the base class of pool2d, prepare to share base class with pool3d

* Add pool3d device op

* Add pool3d f16 example

* Refactor the base class. implement generic pooling in the future

* clang format

* get original index in max pooling

* Add outputindex to base class

* Fix dimension

* Add pooling instance

* Use indexType instead

* Remove useless header

* Extract IndexDataType to template

* Extract pooling reference code

* clang format

* clang format

* Fix typo

* Add tensor stride

* Add missing header

* Add index stride and output stride

* Refine naming

* Add type to base class

* Rename file

* Use proper size

* Fix typo

* Refine naming

* Modify the argument into vector.

* Add max pool profiler

* Refine naming

* Support f32 pool

* Fix typo

* Add avg pool2d fwd in profiler

* clang format

* Rename AccDatatype to ComputeDatatype

* Fix init

* test pool

* Extract variable

* Add client example

* Check the pooling dim

* clang format

* Connect argv and arg_parser

* Add found check

* Remove useless header

* Refine naming

* Adjust the order of device_pool_fwd
This commit is contained in:
rocking
2023-05-24 22:05:04 +08:00
committed by GitHub
parent d821d1e54f
commit 76ec0089fb
44 changed files with 3226 additions and 241 deletions

View File

@@ -0,0 +1,345 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
struct ReferencePoolingFwd : public device::BaseOperator
{
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in),
out_(out),
out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides),
in_left_pads_(in_left_pads),
reduceLength_(1)
{
static_for<0, WindowRank, 1>{}(
[&](auto I) { reduceLength_ *= window_spatial_lengths[I]; });
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float RunPooling3dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
ComputeDataType,
IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
};
return 0;
}
float RunPooling2dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
ComputeDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_indices_(n, c, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
};
return 0;
}
float Run(const Argument& arg)
{
// TODO - support generic pooling
if constexpr(InOutRank == 5 && WindowRank == 3)
return RunPooling3dFwd(arg);
else if constexpr(InOutRank == 4 && WindowRank == 2)
return RunPooling2dFwd(arg);
else
throw std::runtime_error("Only support pooling3d or pooling2d so far");
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads)
{
return Argument{in,
out,
out_indices,
window_spatial_lengths,
window_strides,
in_left_pads,
in_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferencePoolingFwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 5;
static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -90,6 +90,7 @@ void add_device_reduce_instance_threadwise(
AccElementwiseOp,
PropagateNan,
OutputIndex,
false,
false, // HaveIndexInputIfOutputIndex
cfg1::BlockSize_,
cfg2::MThreadSliceSize_,

View File

@@ -411,6 +411,12 @@ struct Tensor
}
}
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
template <typename... Is>
T& operator()(Is... is)
{