mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
General 2D Reduction Kernel (#2535)
* General 2D Reduction Kernel * Move the reduction kernel from the example * Split the code and add the necessary policy, problem, shape files as per ck_tile convention * Add/modify the headers * Modified the example to work with the 'new' kernel * Added tests for the kernel * N-D reference reduce * Added support for N-D input with transform to 2D * Added padding to support input tensors of various sizes * Bug fix in the thread buffer constructor * Some comments to explain the reduce2d block kernel * comments resolution * clang-format * comments resolution * clang-format * clang-format * comments resolution * clang-format
This commit is contained in:
committed by
GitHub
parent
2622ff06cb
commit
4750b293fe
@@ -7,20 +7,55 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
|
||||
// dimension using a user-specified reduction function.
|
||||
//
|
||||
// The reduction is performed in a three-stage hierarchical approach:
|
||||
//
|
||||
// STAGE 1: Thread-level reduction (BlockReduce2d)
|
||||
// ===============================================
|
||||
// - Each thread processes multiple elements from the input tensor within its assigned data
|
||||
// partition
|
||||
// - Reduction is performed locally within each thread by iterating over assigned elements
|
||||
// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
|
||||
// dimension
|
||||
// (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1)
|
||||
// - Results are accumulated into a thread-local output tensor stored in registers
|
||||
// - The output tensor distribution is derived from the input tensor's distribution using
|
||||
// make_reduce_tile_distribution_encoding() to handle dimension reduction
|
||||
//
|
||||
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
|
||||
// ================================================
|
||||
// - Performs inter-thread reduction within each warp
|
||||
// - Uses warp shuffle operations to exchange data between threads in the same warp
|
||||
// - Implements a tree-reduction pattern with power-of-2 stages
|
||||
// - Only reduces along dimensions that map to lane IDs within the warp
|
||||
//
|
||||
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
|
||||
// ========================================================
|
||||
// - Performs reduction across multiple warps within the same thread block
|
||||
// - Uses shared memory (LDS) to facilitate data exchange between warps
|
||||
// - Each warp's lane-0 thread stores its partial results to shared memory
|
||||
// - All threads participate in loading and reducing data from shared memory
|
||||
// - Implements block-level synchronization to ensure memory consistency
|
||||
|
||||
// BlockReduce2d: Thread-level reduction (Stage 1)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2d
|
||||
{
|
||||
// in-thread reduction
|
||||
// Thread-level reduction implementation
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockReduce2d() {}
|
||||
|
||||
template <typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
template <
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim =
|
||||
uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
@@ -33,6 +68,7 @@ struct BlockReduce2d
|
||||
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
|
||||
},
|
||||
ReducePacksPerXDim{});
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
@@ -75,6 +111,8 @@ struct BlockReduce2d
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
|
||||
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
|
||||
template <typename XDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
@@ -91,6 +129,7 @@ struct BlockReduce2d
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dSync: Warp-level reduction (Stage 2)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dSync
|
||||
{
|
||||
@@ -145,8 +184,15 @@ struct BlockReduce2dSync
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -157,6 +203,7 @@ struct BlockReduce2dSync
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dCrossWarpSync
|
||||
{
|
||||
@@ -263,8 +310,15 @@ struct BlockReduce2dCrossWarpSync
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockReduce2dDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user