mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
* Add indexing support to pooling operator - Add IndexDataType template parameter to pooling problem and kernel definitions - Enable pooling kernel to output indices of selected elements during max/absmax pooling - Add overloaded operators for Max and AbsMax that track when values change using bool changed parameter - Support optional index buffer allocation and management in device memory - Modify BlockReduce2d classes to handle index tensors alongside value tensors - Add separate shared memory allocation for index data in cross-warp reductions - Create validate_pool_indices function to verify index correctness - Modify pool3d.cpp example to demonstrate index output functionality - Add tests for index output * fixes * Refactor BlockReduce2D functions to get rid auxiliary private types. * comment resolutions and some changes to block_reduce2d - index reference implementation improved - reduce_operator.hpp cleanedup - updated the block_reduce2d.hpp to have index calculation for BlockReduce2dLinearCrossWarpSync as well * conditionally used variable declaration improvement - the conditionally used vairbales are used only when indexing is enabled. To inform the compiler that they may be unused and declare them with least size possible. This may allow it to be optimized compared to the previous declarations * comment resolutions * lexical ordering of the indicies - introduced accumulate methods that handle the intermediate steps if needed to order the indexes * add reduce_operator_accumulate.hpp to core.hpp --------- Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
199 lines
8.3 KiB
C++
199 lines
8.3 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/host_tensor.hpp"
|
|
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
|
#include <thread>
|
|
#include <cmath>
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename InDataType,
|
|
typename ComputeDataType,
|
|
typename OutDataType,
|
|
typename IndexDataType,
|
|
typename ReduceOp,
|
|
typename TensorShape,
|
|
typename WindowShape,
|
|
bool OutputIndex = false>
|
|
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
|
HostTensor<OutDataType>& output,
|
|
HostTensor<IndexDataType>& output_index,
|
|
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
|
ReduceOp reduce_op)
|
|
{
|
|
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
|
|
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
|
|
|
|
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
|
|
|
|
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});
|
|
|
|
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
|
|
|
|
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});
|
|
|
|
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
|
|
// Right padding is handled implicitly by bounds checking
|
|
|
|
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
|
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
|
|
|
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
|
|
|
for(ck_tile::index_t y = 0; y < Y; ++y)
|
|
{
|
|
// Calculate input height index with stride, dilation, and padding
|
|
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
|
|
|
for(ck_tile::index_t x = 0; x < X; ++x)
|
|
{
|
|
// Calculate input width index with stride, dilation, and padding
|
|
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
|
|
|
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
|
{
|
|
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
|
|
|
if constexpr(OutputIndex)
|
|
{
|
|
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
|
|
bool changed = false;
|
|
v_acc = reduce_op(v_acc, v_in, changed);
|
|
if(changed)
|
|
{
|
|
current_index = flat_index;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
v_acc = reduce_op(v_acc, v_in);
|
|
}
|
|
}
|
|
// For positions outside bounds, we implicitly use identity value
|
|
}
|
|
}
|
|
|
|
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
|
|
|
if constexpr(OutputIndex)
|
|
{
|
|
output_index(n, ho, wo, c) = current_index;
|
|
}
|
|
};
|
|
|
|
// Parallelize over all output dimensions
|
|
make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
|
|
}
|
|
|
|
template <typename InDataType,
|
|
typename ComputeDataType,
|
|
typename OutDataType,
|
|
typename IndexDataType,
|
|
typename ReduceOp,
|
|
typename TensorShape,
|
|
typename WindowShape,
|
|
bool OutputIndex = false>
|
|
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
|
HostTensor<OutDataType>& output,
|
|
HostTensor<IndexDataType>& output_index,
|
|
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
|
ReduceOp reduce_op)
|
|
{
|
|
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
|
|
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
|
|
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
|
|
|
|
const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
|
|
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
|
|
|
|
const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
|
|
|
|
const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
|
|
|
|
const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});
|
|
|
|
const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
|
|
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
|
|
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
|
|
// Right padding is handled implicitly by bounds checking
|
|
|
|
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
|
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
|
|
|
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
|
|
|
for(ck_tile::index_t z = 0; z < Z; ++z)
|
|
{
|
|
// Calculate input depth index with stride, dilation, and padding
|
|
ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
|
|
|
|
for(ck_tile::index_t y = 0; y < Y; ++y)
|
|
{
|
|
// Calculate input height index with stride, dilation, and padding
|
|
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
|
|
|
for(ck_tile::index_t x = 0; x < X; ++x)
|
|
{
|
|
// Calculate input width index with stride, dilation, and padding
|
|
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
|
|
|
if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
|
|
{
|
|
const ComputeDataType v_in =
|
|
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
|
|
|
if constexpr(OutputIndex)
|
|
{
|
|
IndexDataType flat_index =
|
|
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
|
|
bool changed = false;
|
|
v_acc = reduce_op(v_acc, v_in, changed);
|
|
if(changed)
|
|
{
|
|
current_index = flat_index;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
v_acc = reduce_op(v_acc, v_in);
|
|
}
|
|
}
|
|
// For positions outside bounds, we implicitly use identity value
|
|
}
|
|
}
|
|
}
|
|
|
|
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
|
|
|
if constexpr(OutputIndex)
|
|
{
|
|
|
|
output_index(n, do_, ho, wo, c) = current_index;
|
|
}
|
|
};
|
|
|
|
// Parallelize over all output dimensions
|
|
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
|
}
|
|
} // namespace ck_tile
|