mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Pooling FWD (Lwpck 3683) (#2956)
* Pooling 2D/3D with reference * Tests & cleanup - added test for pooling - cleanup - removed 2d example * Comment resolution - README added - example target name rectified - appropriate arg description and comments added * clang-format * appropriate blocksize calc * modifications for future indexing addition - instead of transforming views we now transform the descriptors, so that the same descriptor can be re-used for index tensor in the future * some basic fixes * comment resolutions * comment resolutions --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d4bfe3932
commit
7b6451b68e
147
include/ck_tile/host/reference/reference_pool.hpp
Normal file
147
include/ck_tile/host/reference/reference_pool.hpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
|
||||
|
||||
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
|
||||
// Right padding is handled implicitly by bounds checking
|
||||
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
||||
|
||||
for(ck_tile::index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Calculate input width index with stride, dilation, and padding
|
||||
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
||||
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
|
||||
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
|
||||
|
||||
const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
|
||||
|
||||
const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
|
||||
// Right padding is handled implicitly by bounds checking
|
||||
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
||||
|
||||
for(ck_tile::index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Calculate input width index with stride, dilation, and padding
|
||||
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
||||
|
||||
if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user