[CK_TILE] Add indexing to pooling operator (Lwpck 3892) (#3013)

* Add indexing support to pooling operator

- Add IndexDataType template parameter to pooling problem and kernel
definitions

- Enable pooling kernel to output indices of selected elements during
max/absmax pooling

- Add overloaded operators for Max and AbsMax that track when values
change using bool changed parameter

-  Support optional index buffer allocation and management in device
memory

- Modify BlockReduce2d classes to handle index tensors alongside value
tensors

-  Add separate shared memory allocation for index data in cross-warp
reductions

- Create validate_pool_indices function to verify index correctness

- Modify pool3d.cpp example to demonstrate index output functionality

- Add tests for index output

* fixes

* Refactor BlockReduce2D functions to get rid auxiliary private types.

* comment resolutions and some changes to block_reduce2d

- index reference implementation improved
- reduce_operator.hpp cleanedup
- updated the block_reduce2d.hpp to have index calculation for
BlockReduce2dLinearCrossWarpSync as well

* conditionally used variable declaration improvement

- the conditionally used vairbales are used only when indexing is
enabled. To inform the compiler that they may be unused and declare them
with least size possible. This may allow it to be optimized compared to
the previous declarations

* comment resolutions

* lexical ordering of the indicies

- introduced accumulate methods that handle the intermediate steps if
needed to order the indexes

* add reduce_operator_accumulate.hpp to core.hpp

---------

Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
Yashvardhan Agarwal
2025-10-29 09:58:04 +02:00
committed by GitHub
parent 7c6430eca0
commit 3052d7c9e6
13 changed files with 860 additions and 99 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
namespace ck_tile {
@@ -50,6 +51,53 @@ struct BlockReduce2d
CK_TILE_DEVICE constexpr BlockReduce2d() {}
private:
template <bool kProcessIndex,
typename XDistributedTensor_,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc,
typename IndexCalculatorFunc,
typename ReducePacksPerXDim>
CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func,
const IndexCalculatorFunc& index_calculator,
ReducePacksPerXDim)
{
sweep_tile<XDistributedTensor_>(
[&](auto... idx_) {
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
(..., [&](auto idx) {
auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
if constexpr(kProcessIndex)
{
const auto x_indices = get_x_indices_from_distributed_indices(
XDistributedTensor_::get_tile_distribution(), idx);
const auto new_idx = index_calculator(x_indices);
auto current_idx = y_index_tensor(idx_0);
AccumulateWithIndex{}(
reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
y_index_tensor(idx_0) =
type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
}
else
{
Accumulate{}(reduce_func, y_tensor(idx_0), val);
}
}(idx_));
},
ReducePacksPerXDim{});
}
public:
// Overload for non-index tracking
template <
typename XDistributedTensor_,
typename YDistributedTensor_,
@@ -61,13 +109,36 @@ struct BlockReduce2d
const ReduceFunc& reduce_func,
ReducePacksPerXDim = {})
{
sweep_tile<XDistributedTensor_>(
[&](auto... idx_) {
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
y_tensor(idx_0) = reduce_func(
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
},
reduce_impl<false>(
x_tensor,
y_tensor,
y_tensor, // dummy
reduce_func,
[](auto) { return 0; }, // dummy
ReducePacksPerXDim{});
}
// Overload for index tracking
template <typename XDistributedTensor_,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc,
typename IndexCalculatorFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func,
const IndexCalculatorFunc& index_calculator,
ReducePacksPerXDim = {})
{
reduce_impl<Problem::kOutputIndex>(x_tensor,
y_tensor,
y_index_tensor,
reduce_func,
index_calculator,
ReducePacksPerXDim{});
}
#if 0
constexpr auto I0 = number<0>{};
@@ -90,7 +161,6 @@ struct BlockReduce2d
y_tensor(y_dstr_idx) = y;
});
#endif
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
@@ -111,6 +181,25 @@ struct BlockReduce2d
return tensor;
}
template <typename XDistributedTensor_, typename IndexDataType = index_t>
CK_TILE_DEVICE static auto MakeYIndexBlockTile()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
// FIXME: hard coded to reduce 2nd axis
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
return tensor;
}
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
template <typename XDistributedTensor_,
@@ -135,8 +224,14 @@ struct BlockReduce2dSync
{
using Problem = remove_cvref_t<Problem_>;
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
@@ -157,6 +252,14 @@ struct BlockReduce2dSync
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = y_tensor.get_thread_buffer()[i];
using IndexDataType = typename YIndexDistributedTensor_::DataType;
IndexDataType idx_local{};
if constexpr(kProcessIndex)
{
idx_local = y_index_tensor.get_thread_buffer()[i];
}
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
@@ -183,15 +286,46 @@ struct BlockReduce2dSync
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
v_local = reduce_func(v_local, v_remote);
if constexpr(kProcessIndex)
{
const auto idx_remote = warp_shuffle(idx_local, src_lane);
AccumulateWithIndex{}(
reduce_func, v_local, idx_local, v_remote, idx_remote);
}
else
{
Accumulate{}(reduce_func, v_local, v_remote);
}
});
}
});
// TODO - Do we need to broadcast to other lane?
y_tensor.get_thread_buffer()(i) = v_local;
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i) = idx_local;
}
});
}
public:
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
{
reduce_impl<false>(y_tensor, y_tensor, reduce_func);
}
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func)
{
reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
}
};
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
@@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
// return in byte - separate shared memory size calculation for indices
template <typename YIndexDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(IndexDataType);
}
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices_ptr,
const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
IndexDataType* smem_indices = nullptr;
if constexpr(kProcessIndex)
{
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
}
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
@@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i) {
// Store the i-th element of this warp's thread_buffer into SMEM
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
if constexpr(kProcessIndex)
{
smem_indices[smem_offset + i * num_warps] =
y_index_tensor.get_thread_buffer()[i];
}
});
}
block_sync_lds();
@@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync
// We let each warp holds a duplication to do reduction.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v[num_reduce_warps];
static_for<0, num_reduce_warps, 1>{}(
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
[[maybe_unused]] std::
conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
if constexpr(kProcessIndex)
{
idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
}
});
static_assert(is_power_of_two_integer(num_reduce_warps),
"wrong! only support power of 2 reduction");
@@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync
constexpr index_t i1 = idx_ + stride;
if constexpr(i1 < num_reduce_warps)
{
v[i0] = reduce_func(v[i0], v[i1]);
if constexpr(kProcessIndex)
{
AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
}
else
{
Accumulate{}(reduce_func, v[i0], v[i1]);
}
}
});
});
y_tensor.get_thread_buffer()(i) = v[0];
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i) = idx_v[0];
}
});
}
public:
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
}
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices,
const ReduceFunc& reduce_func)
{
reduce_impl<Problem::kOutputIndex>(
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
}
};
template <typename Problem_, typename Policy_ = void>
@@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
// return in byte - separate shared memory size calculation for indices
template <typename YIndexDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(IndexDataType);
}
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices_ptr,
const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
IndexDataType* smem_indices = nullptr;
if constexpr(kProcessIndex)
{
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
}
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
@@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
if constexpr(kProcessIndex)
{
smem_indices[smem_offset + i * num_warps] =
y_index_tensor.get_thread_buffer()[i];
}
});
}
block_sync_lds();
@@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync
// load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
DataType all_scratch[thread_buf_size * num_reduce_warps];
[[maybe_unused]] std::conditional_t<kProcessIndex,
IndexDataType[thread_buf_size * num_reduce_warps],
IndexDataType> all_indices;
// Load data from shared memory
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
if constexpr(kProcessIndex)
{
all_indices[i_0 * num_reduce_warps + i_1] =
smem_indices[i_0 * num_warps + local_smem_os + i_1];
}
});
});
block_sync_lds(); // TODO: we don't need sync here
// Perform reduction
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
IndexDataType idx_local{};
if constexpr(kProcessIndex)
{
idx_local = all_indices[i_0 * num_reduce_warps];
}
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
if constexpr(kProcessIndex)
{
const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
bool changed = false;
v_local = reduce_func(v_local, v_remote, changed);
if(changed)
{
idx_local = idx_remote;
}
}
else
{
v_local = reduce_func(v_local, v_remote);
}
});
y_tensor.get_thread_buffer()(i_0) = v_local;
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i_0) = idx_local;
}
});
}
public:
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
}
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices,
const ReduceFunc& reduce_func)
{
reduce_impl<Problem::kOutputIndex>(
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
}
};
} // namespace ck_tile

View File

@@ -7,12 +7,17 @@
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
template <typename XDataType_,
typename ComputeDataType_,
typename BlockShape_,
bool OutputIndex_ = false>
struct BlockReduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kOutputIndex = OutputIndex_;
};
} // namespace ck_tile

View File

@@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2d<P_>{};
}
@@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dSync<P_>{};
}
@@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dCrossWarpSync<P_>{};
}
@@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
@@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy
return 1; // zero size arrays are an extension
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::kOutputIndex>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::XDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_index_block_tile =
decltype(block_reduce2d::template MakeYIndexBlockTile<x_block_tile, index_t>());
return GetBlockReduce2dCrossWarpSync<Problem>()
.template GetIndicesSmemSize<y_index_block_tile>();
}
};
} // namespace ck_tile

View File

@@ -11,7 +11,8 @@ template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename BlockShape_,
typename ReduceOp_>
typename ReduceOp_,
bool OutputIndex_ = false>
struct Reduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
@@ -20,6 +21,7 @@ struct Reduce2dProblem
using BlockShape = remove_cvref_t<BlockShape_>;
using ReduceOp = ReduceOp_;
static constexpr bool kOutputIndex = OutputIndex_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};