mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 16:04:58 +00:00
Refactor pool fwd (#815)
* Do not hardcode stride
* devicePool2DFwd Inherit devicePool3DFwd
* Move instance declaration out of common
* Add dilation
* use the pool3d rank, because pool2d inherit pooo3d
* calculate Do Ho Wo for the dilation
* Fix header name
* Modify ckProfiler
* Remove pool2d instance
* Remove pool2d in profiler
* Remove pool2d and add dilation
* In to client example, this commit revise following:
1. Add dilation.
2. Use pool3d to implement pool2d
* Refine naming and IsSupportedArgument()
* Add dilation to maxpool bwd example
* clang format
* 1. Remove useless header
2. Fix copyright
3. Refine naming
* Add layout parameter to pool fwd
* clang format
* Fix merge error
* Fix compile error
* Remove layout parameter in derived class
* Refine changlog
* Fix compile error
* Fix compiler error
* Add layout to external api and profiler
[ROCm/composable_kernel commit: f60f0a5e03]
This commit is contained in:
@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices,
|
||||
const std::vector<ck::index_t>& window_spatial_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& in_left_pads,
|
||||
const std::vector<ck::index_t>& /*in_right_pads*/)
|
||||
: in_(in),
|
||||
@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
out_indices_(out_indices),
|
||||
window_spatial_lengths_(window_spatial_lengths),
|
||||
window_strides_(window_strides),
|
||||
window_dilations_(window_dilations),
|
||||
in_left_pads_(in_left_pads),
|
||||
reduceLength_(1)
|
||||
{
|
||||
@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices_;
|
||||
const std::vector<ck::index_t>& window_spatial_lengths_;
|
||||
const std::vector<ck::index_t>& window_strides_;
|
||||
const std::vector<ck::index_t>& window_dilations_;
|
||||
const std::vector<ck::index_t>& in_left_pads_;
|
||||
int reduceLength_;
|
||||
};
|
||||
@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
|
||||
ck::index_t di = do_ * arg.window_strides_[0] +
|
||||
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
|
||||
ck::index_t hi = ho * arg.window_strides_[1] +
|
||||
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
ck::index_t wi =
|
||||
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
|
||||
ck::index_t wi = wo * arg.window_strides_[2] +
|
||||
x * arg.window_dilations_[2] -
|
||||
arg.in_left_pads_[2];
|
||||
if(di >= 0 &&
|
||||
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
hi >= 0 &&
|
||||
@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
|
||||
ck::index_t di = do_ * arg.window_strides_[0] +
|
||||
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
|
||||
ck::index_t hi = ho * arg.window_strides_[1] +
|
||||
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
ck::index_t wi =
|
||||
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
|
||||
ck::index_t wi = wo * arg.window_strides_[2] +
|
||||
x * arg.window_dilations_[2] -
|
||||
arg.in_left_pads_[2];
|
||||
if(di >= 0 &&
|
||||
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
hi >= 0 &&
|
||||
@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
|
||||
ck::index_t hi = ho * arg.window_strides_[0] +
|
||||
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
|
||||
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
|
||||
ck::index_t wi = wo * arg.window_strides_[1] +
|
||||
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
|
||||
if(hi >= 0 &&
|
||||
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
|
||||
wi >= 0 &&
|
||||
@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
Tensor<IndexDataType>& out_indices,
|
||||
const std::vector<ck::index_t>& window_spatial_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& in_left_pads,
|
||||
const std::vector<ck::index_t>& in_right_pads)
|
||||
{
|
||||
@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
out_indices,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads};
|
||||
}
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto InOutRank = 4;
|
||||
static constexpr auto WindowRank = 2;
|
||||
|
||||
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
|
||||
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
|
||||
#ifdef __fp16__
|
||||
// FP16
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
|
||||
|
||||
// FP16 - return index
|
||||
void add_device_pool2d_fwd_nhwc_index_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
// FP32
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
|
||||
|
||||
// FP32 - return index
|
||||
void add_device_pool2d_fwd_nhwc_index_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
|
||||
#endif
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
|
||||
WindowRank,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex>>
|
||||
{
|
||||
using DeviceOp = DevicePoolFwd<InOutRank,
|
||||
WindowRank,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,36 +25,38 @@ static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
|
||||
#ifdef __fp16__
|
||||
// FP16
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
|
||||
|
||||
// FP16 - return index
|
||||
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
// FP32
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, false>>>&);
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, AvgOp, false>>>&);
|
||||
|
||||
// FP32 - return index
|
||||
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename InLayout,
|
||||
typename OutLayout,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
|
||||
@@ -62,6 +64,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>>
|
||||
{
|
||||
@@ -70,40 +74,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
#ifdef __fp16__
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef __fp32__
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
set(DEVICE_POOL3D_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})
|
||||
@@ -11,7 +11,9 @@ namespace instance {
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
|
||||
@@ -11,7 +11,9 @@ namespace instance {
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
@@ -11,14 +11,18 @@ namespace instance {
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, true>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
|
||||
@@ -11,14 +11,18 @@ namespace instance {
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, true>>>& instances)
|
||||
std::vector<
|
||||
std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I32 = int32_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using NDHWC = ck::tensor_layout::convolution::NDHWC;
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ComputeDataType,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
using device_pool3d_fwd_ndhwc_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
DevicePool3dFwd_NDHWC_NDHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
|
||||
DevicePool3dFwd_NDHWC_NDHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
|
||||
DevicePool3dFwd_NDHWC_NDHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,14 +0,0 @@
|
||||
set(DEVICE_POOL_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_pool_fwd_instance ${DEVICE_POOL_FWD_INSTANCES})
|
||||
@@ -1,23 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,23 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,30 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,30 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,55 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I32 = int32_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ComputeDataType,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
using device_pool2d_fwd_nhwc_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ComputeDataType,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
using device_pool3d_fwd_ndhwc_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user