Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 08:15:15 +00:00
parent e571490afc
commit 6f6c855c0e
13 changed files with 860 additions and 99 deletions

View File

@@ -7,17 +7,21 @@
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include <thread>
#include <cmath>
namespace ck_tile {
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
auto f = [&](auto n, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
if constexpr(OutputIndex)
{
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
}
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t z = 0; z < Z; ++z)
{
// Calculate input depth index with stride, dilation, and padding
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
{
const ComputeDataType v_in =
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
if constexpr(OutputIndex)
{
IndexDataType flat_index =
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
}
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, do_, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile