[CK_TILE] Pooling FWD (Lwpck 3683) (#2956)

* Pooling 2D/3D with refernce

* Tests & cleanup

- added test for ppoling
- cleanup
- removed 2d example

* Comment resolution

- README added
- example target name rectified
- appropriate arg description and comments added

* clang-format

* appropriate blocksize calc

* modifications for future indexing addition

- instead of transforming views we now transform the descriptors, so
that the same descriptor can be re-used for index tensor in the future

* some basic fixes

* comment resolutions

* comment resolutions

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Yashvardhan Agarwal
2025-10-09 17:13:26 +03:00
committed by GitHub
parent 9d4bfe3932
commit 7b6451b68e
14 changed files with 1317 additions and 0 deletions

View File

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
// Right padding is handled implicitly by bounds checking
auto f = [&](auto n, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
for(ck_tile::index_t x = 0; x < X; ++x)
{
// Calculate input width index with stride, dilation, and padding
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
}
// For positions outside bounds, we implicitly use identity value
}
}
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
}
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});
const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
// Right padding is handled implicitly by bounds checking
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
for(ck_tile::index_t z = 0; z < Z; ++z)
{
// Calculate input depth index with stride, dilation, and padding
ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
for(ck_tile::index_t x = 0; x < X; ++x)
{
// Calculate input width index with stride, dilation, and padding
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in =
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
}
// For positions outside bounds, we implicitly use identity value
}
}
}
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile