mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -7,17 +7,21 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include <thread>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
output_index(n, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index =
|
||||
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
output_index(n, do_, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user