mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
General 2D Reduction Kernel (#2535)
* General 2D Reduction Kernel * Move the reduction kernel from the example * Split the code and add the necessary policy, problem, shape files as per ck_tile convention * Add/modify the headers * Modified the example to work with the 'new' kernel * Added tests for the kernel * N-D refernce reduce * Added support for N-D input with transform to 2D * Added padding to support various input sized tensors * Bug fix in the thread buffer constructor * Some comments to explain the reduce2d block kernel * comments resolution * clang-format * comments resolution * clang-format * clang-format * comments resolution * clang-format
This commit is contained in:
committed by
GitHub
parent
2622ff06cb
commit
4750b293fe
@@ -7,20 +7,55 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
|
||||
// dimension using a user-specified reduction function.
|
||||
//
|
||||
// The reduction is performed in a three-stage hierarchical approach:
|
||||
//
|
||||
// STAGE 1: Thread-level reduction (BlockReduce2d)
|
||||
// ===============================================
|
||||
// - Each thread processes multiple elements from the input tensor within its assigned data
|
||||
// partition
|
||||
// - Reduction is performed locally within each thread by iterating over assigned elements
|
||||
// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
|
||||
// dimension
|
||||
// (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1)
|
||||
// - Results are accumulated into a thread-local output tensor stored in registers
|
||||
// - The output tensor distribution is derived from the input tensor's distribution using
|
||||
// make_reduce_tile_distribution_encoding() to handle dimension reduction
|
||||
//
|
||||
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
|
||||
// ================================================
|
||||
// - Performs inter-thread reduction within each warp
|
||||
// - Uses warp shuffle operations to exchange data between threads in the same warp
|
||||
// - Implements a tree-reduction pattern with power-of-2 stages
|
||||
// - Only reduces along dimensions that map to lane IDs within the warp
|
||||
//
|
||||
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
|
||||
// ========================================================
|
||||
// - Performs reduction across multiple warps within the same thread block
|
||||
// - Uses shared memory (LDS) to facilitate data exchange between warps
|
||||
// - Each warp's lane-0 thread stores its partial results to shared memory
|
||||
// - All threads participate in loading and reducing data from shared memory
|
||||
// - Implements block-level synchronization to ensure memory consistency
|
||||
|
||||
// BlockReduce2d: Thread-level reduction (Stage 1)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2d
|
||||
{
|
||||
// in-thread reduction
|
||||
// Thread-level reduction implementation
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockReduce2d() {}
|
||||
|
||||
template <typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
template <
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim =
|
||||
uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
@@ -33,6 +68,7 @@ struct BlockReduce2d
|
||||
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
|
||||
},
|
||||
ReducePacksPerXDim{});
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
@@ -75,6 +111,8 @@ struct BlockReduce2d
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
|
||||
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
|
||||
template <typename XDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
@@ -91,6 +129,7 @@ struct BlockReduce2d
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dSync: Warp-level reduction (Stage 2)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dSync
|
||||
{
|
||||
@@ -145,8 +184,15 @@ struct BlockReduce2dSync
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -157,6 +203,7 @@ struct BlockReduce2dSync
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dCrossWarpSync
|
||||
{
|
||||
@@ -263,8 +310,15 @@ struct BlockReduce2dCrossWarpSync
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
|
||||
219
include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp
Normal file
219
include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp
Normal file
@@ -0,0 +1,219 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
|
||||
|
||||
// Reduce2d Kernel:
|
||||
// =======================================
|
||||
// This kernel implements a 2D reduction operation that reduces data along the second dimension
|
||||
// of a matrix. The reduction is performed in multiple hierarchical stages.
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Reduce2dDefaultPolicy>
|
||||
struct Reduce
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
private:
|
||||
// Helper function to calculate optimal vector size for input tensor
|
||||
template <typename InputShape, typename ReduceDims>
|
||||
static constexpr index_t CalculateInputVectorSize()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
constexpr index_t memory_vector_size = 16 / sizeof(XDataType);
|
||||
constexpr index_t thread_tile_vector_size = S::ThreadTile_N;
|
||||
|
||||
// Check if innermost reduce dimension is the last dimension (stride 1).
|
||||
constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
|
||||
constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1);
|
||||
|
||||
// If innermost reduce dimension is not the last dim (not contiguous), limit vectorization
|
||||
constexpr index_t stride_based_vector_size =
|
||||
is_innermost_contiguous ? ck_tile::min(memory_vector_size, thread_tile_vector_size) : 1;
|
||||
|
||||
return stride_based_vector_size;
|
||||
}
|
||||
|
||||
// Helper function to calculate optimal vector size for output tensor
|
||||
static constexpr index_t CalculateOutputVectorSize()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
constexpr index_t memory_vector_size = 16 / sizeof(YDataType);
|
||||
constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
|
||||
constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);
|
||||
|
||||
return vector_size;
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename InputShape, typename InputStrides, typename KeptDim, typename ReduceDims>
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x,
|
||||
YDataType* p_y,
|
||||
InputShape input_shape,
|
||||
InputStrides input_strides,
|
||||
KeptDim kept_dim,
|
||||
ReduceDims reduce_dims) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
|
||||
"Size of kept dimensions + reduced dimensions must equal input tensor rank");
|
||||
|
||||
// Extract lengths based on kept and reduced dimensions
|
||||
const auto kept_lens = [&]() {
|
||||
return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
|
||||
number<kept_dim.size()>{});
|
||||
}();
|
||||
const auto reduce_lens = [&]() {
|
||||
return generate_tuple(
|
||||
[&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
|
||||
number<reduce_dims.size()>{});
|
||||
}();
|
||||
|
||||
const auto kept_merge_transform = make_merge_transform(kept_lens);
|
||||
const auto reduce_merge_transform = make_merge_transform(reduce_lens);
|
||||
|
||||
auto reduce_func = typename Problem::ReduceOp{};
|
||||
const XDataType custom_padding_value =
|
||||
type_convert<XDataType>(reduce_func.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
// Calculate optimal vector size for input tensor
|
||||
constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();
|
||||
|
||||
// Create input tensor view with custom padding value
|
||||
auto desc = make_naive_tensor_descriptor(
|
||||
input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});
|
||||
|
||||
// Create buffer view with custom padding value
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_x, desc.get_element_space_size(), custom_padding_value);
|
||||
|
||||
// Create tensor view with custom padding
|
||||
const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
const auto transformed_x_tensor = pad_tensor_view(
|
||||
transform_tensor_view(x_tensor,
|
||||
make_tuple(kept_merge_transform, reduce_merge_transform),
|
||||
make_tuple(kept_dim, reduce_dims),
|
||||
make_tuple(sequence<0>{}, sequence<1>{})),
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
sequence<0, 1>{});
|
||||
|
||||
// Calculate strides for output tensor based on its own dimensions
|
||||
const auto kept_strides = [&]() {
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
// Calculate stride for dimension I as product of all following dimensions
|
||||
index_t stride = 1;
|
||||
static_for<I + 1, kept_dim.size(), 1>{}(
|
||||
[&](auto J) { stride *= kept_lens.at(number<J>{}); });
|
||||
return stride;
|
||||
},
|
||||
number<kept_dim.size()>{});
|
||||
}();
|
||||
|
||||
// Calculate optimal vector size for output tensor
|
||||
constexpr auto y_tensor_vector_size = CalculateOutputVectorSize();
|
||||
|
||||
const auto y_m = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, kept_lens, kept_strides, number<y_tensor_vector_size>{}, number<1>{});
|
||||
|
||||
// Transform output tensor to 1D merged view
|
||||
// This creates a view compatible with the 2D reduction pattern
|
||||
const auto y_merged = transform_tensor_view(
|
||||
y_m,
|
||||
make_tuple(kept_merge_transform),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
auto x_window = make_tile_window(transformed_x_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_merged, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
__shared__ char smem[Policy::template GetSmemSize<Problem>()];
|
||||
|
||||
// Get the merged dimension size from the transformed tensor
|
||||
const auto merged_reduce_len =
|
||||
transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{});
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(merged_reduce_len, S::Block_N));
|
||||
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_reduce2d(x, y_compute, reduce_func);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_compute, reduce_func);
|
||||
block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the 2D reduction kernel.
|
||||
///
|
||||
/// @param y_continous_dim Size of the continuous dimension of the output tensor.
|
||||
/// Must be a multiple of ThreadTile_N for proper thread mapping.
|
||||
///
|
||||
/// @param input_strides The stride configuration of the input tensor.
|
||||
/// The last stride must be 1 to ensure contiguous memory access
|
||||
/// and enable efficient vectorized loads.
|
||||
///
|
||||
/// @return true if the arguments are supported, false otherwise.
|
||||
/// Error messages are logged when CK_TILE_LOGGING is enabled.
|
||||
///
|
||||
/// @note Requirements:
|
||||
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
|
||||
/// - input_strides[-1] == 1 (for contiguous memory access)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
if(y_continous_dim % S::ThreadTile_N != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Input tensor's last stride must be 1 to support correct vector access!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockReduce2dDefaultPolicy
|
||||
struct Reduce2dDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
@@ -18,8 +18,9 @@ struct BlockReduce2dDefaultPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
27
include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp
Normal file
27
include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
37
include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
Normal file
37
include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct Reduce2dShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user