mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Refactor pool fwd (#815)
* Do not hardcode stride
* devicePool2DFwd inherits devicePool3DFwd
* Move instance declaration out of common
* Add dilation
* Use the pool3d rank, because pool2d inherits pool3d
* calculate Do Ho Wo for the dilation
* Fix header name
* Modify ckProfiler
* Remove pool2d instance
* Remove pool2d in profiler
* Remove pool2d and add dilation
* In the client example, this commit revises the following:
1. Add dilation.
2. Use pool3d to implement pool2d
* Refine naming and IsSupportedArgument()
* Add dilation to maxpool bwd example
* clang format
* 1. Remove useless header
2. Fix copyright
3. Refine naming
* Add layout parameter to pool fwd
* clang format
* Fix merge error
* Fix compile error
* Remove layout parameter in derived class
* Refine changelog
* Fix compile error
* Fix compiler error
* Add layout to external api and profiler
[ROCm/composable_kernel commit: f60f0a5e03]
This commit is contained in:
@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices,
|
||||
const std::vector<ck::index_t>& window_spatial_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& in_left_pads,
|
||||
const std::vector<ck::index_t>& /*in_right_pads*/)
|
||||
: in_(in),
|
||||
@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
out_indices_(out_indices),
|
||||
window_spatial_lengths_(window_spatial_lengths),
|
||||
window_strides_(window_strides),
|
||||
window_dilations_(window_dilations),
|
||||
in_left_pads_(in_left_pads),
|
||||
reduceLength_(1)
|
||||
{
|
||||
@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices_;
|
||||
const std::vector<ck::index_t>& window_spatial_lengths_;
|
||||
const std::vector<ck::index_t>& window_strides_;
|
||||
const std::vector<ck::index_t>& window_dilations_;
|
||||
const std::vector<ck::index_t>& in_left_pads_;
|
||||
int reduceLength_;
|
||||
};
|
||||
@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
|
||||
ck::index_t di = do_ * arg.window_strides_[0] +
|
||||
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
|
||||
ck::index_t hi = ho * arg.window_strides_[1] +
|
||||
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
ck::index_t wi =
|
||||
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
|
||||
ck::index_t wi = wo * arg.window_strides_[2] +
|
||||
x * arg.window_dilations_[2] -
|
||||
arg.in_left_pads_[2];
|
||||
if(di >= 0 &&
|
||||
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
hi >= 0 &&
|
||||
@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
|
||||
ck::index_t di = do_ * arg.window_strides_[0] +
|
||||
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
|
||||
ck::index_t hi = ho * arg.window_strides_[1] +
|
||||
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
ck::index_t wi =
|
||||
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
|
||||
ck::index_t wi = wo * arg.window_strides_[2] +
|
||||
x * arg.window_dilations_[2] -
|
||||
arg.in_left_pads_[2];
|
||||
if(di >= 0 &&
|
||||
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
hi >= 0 &&
|
||||
@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
|
||||
ck::index_t hi = ho * arg.window_strides_[0] +
|
||||
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
|
||||
ck::index_t wi = wo * arg.window_strides_[1] +
|
||||
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
if(hi >= 0 &&
|
||||
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
wi >= 0 &&
|
||||
@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices,
|
||||
const std::vector<ck::index_t>& window_spatial_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& in_left_pads,
|
||||
const std::vector<ck::index_t>& in_right_pads)
|
||||
{
|
||||
@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
out_indices,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads};
|
||||
}
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto InOutRank = 4;
|
||||
static constexpr auto WindowRank = 2;
|
||||
|
||||
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
|
||||
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
|
||||
#ifdef __fp16__
|
||||
// FP16
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
|
||||
|
||||
// FP16 - return index
|
||||
void add_device_pool2d_fwd_nhwc_index_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
// FP32
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
|
||||
|
||||
// FP32 - return index
|
||||
void add_device_pool2d_fwd_nhwc_index_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
|
||||
#endif
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
|
||||
WindowRank,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex>>
|
||||
{
|
||||
using DeviceOp = DevicePoolFwd<InOutRank,
|
||||
WindowRank,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,36 +25,38 @@ static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
|
||||
#ifdef __fp16__
|
||||
// FP16
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
|
||||
|
||||
// FP16 - return index
|
||||
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
// FP32
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, AvgOp, false>>>&);
|
||||
|
||||
// FP32 - return index
|
||||
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename InLayout,
|
||||
typename OutLayout,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
|
||||
@@ -62,6 +64,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>>
|
||||
{
|
||||
@@ -70,40 +74,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user