[CK_TILE] Add indexing to pooling operator (Lwpck 3892) (#3013)

* Add indexing support to pooling operator

- Add IndexDataType template parameter to pooling problem and kernel
definitions

- Enable pooling kernel to output indices of selected elements during
max/absmax pooling

- Add overloaded operators for Max and AbsMax that track when values
change using bool changed parameter

-  Support optional index buffer allocation and management in device
memory

- Modify BlockReduce2d classes to handle index tensors alongside value
tensors

-  Add separate shared memory allocation for index data in cross-warp
reductions

- Create validate_pool_indices function to verify index correctness

- Modify pool3d.cpp example to demonstrate index output functionality

- Add tests for index output

* fixes

* Refactor BlockReduce2D functions to get rid auxiliary private types.

* comment resolutions and some changes to block_reduce2d

- index reference implementation improved
- reduce_operator.hpp cleanedup
- updated the block_reduce2d.hpp to have index calculation for
BlockReduce2dLinearCrossWarpSync as well

* conditionally used variable declaration improvement

- the conditionally used vairbales are used only when indexing is
enabled. To inform the compiler that they may be unused and declare them
with least size possible. This may allow it to be optimized compared to
the previous declarations

* comment resolutions

* lexical ordering of the indicies

- introduced accumulate methods that handle the intermediate steps if
needed to order the indexes

* add reduce_operator_accumulate.hpp to core.hpp

---------

Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
Yashvardhan Agarwal
2025-10-29 09:58:04 +02:00
committed by GitHub
parent 7c6430eca0
commit 3052d7c9e6
13 changed files with 860 additions and 99 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -18,16 +19,14 @@ struct Add
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + x;
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -46,16 +45,14 @@ struct SquareAdd
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + (x * x);
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -66,48 +63,74 @@ struct SquareAdd
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::lowest();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
}
// Overload with changed flag for index tracking
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
{
T new_max = max(y, x);
if(x > y)
{
changed = true;
}
return new_max;
}
};
struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::zero();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));
}
// Overload with changed flag for index tracking
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
{
T new_max = max(y, abs(x));
if(abs(x) > y)
{
changed = true;
}
return new_max;
}
};
} // namespace ReduceOp

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
struct AccumulateWithIndex
{
template <typename ReduceOp, typename T, typename IndexType>
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
T& current_value,
IndexType& current_index,
const T& new_value,
const IndexType& new_index) const
{
bool changed = false;
current_value = reduce_func(current_value, new_value, changed);
if(changed)
{
current_index = new_index;
}
else if(new_index < current_index)
{
bool reverse_changed = false;
reduce_func(new_value, current_value, reverse_changed);
if(!reverse_changed)
{
current_index = new_index;
}
}
}
};
struct Accumulate
{
template <typename ReduceOp, typename T>
CK_TILE_HOST_DEVICE void
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
{
current_value = reduce_func(current_value, new_value);
}
};
} // namespace ck_tile