General 2D Reduction Kernel (#2535)

* General 2D Reduction Kernel

* Move the reduction kernel from the example
* Split the code and add the necessary policy, problem, shape files as
per ck_tile convention
* Add/modify the headers
* Modified the example to work with the 'new' kernel
* Added tests for the kernel
* N-D refernce reduce
* Added support for N-D input with transform to 2D
* Added padding to support various input sized tensors
* Bug fix in the thread buffer constructor
* Some comments to explain the reduce2d block kernel

* comments resolution

* clang-format

* comments resolution

* clang-format

* clang-format

* comments resolution

* clang-format

[ROCm/composable_kernel commit: 4750b293fe]
This commit is contained in:
Yashvardhan Agarwal
2025-08-06 16:36:59 +03:00
committed by GitHub
parent b7659e284a
commit b6f0e98da6
14 changed files with 905 additions and 199 deletions

View File

@@ -42,7 +42,11 @@ struct thread_buffer {
// TODO: this ctor can't ignore
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
static_for<0, N, 1>{}(
[&](auto i) { data[i] = o; }
);
}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE auto & get() {return data; }

View File

@@ -26,7 +26,8 @@ struct Add
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -34,6 +35,8 @@ struct Add
return type_convert<T>(y_ + x_);
}
static constexpr bool requires_special_combine = false;
};
struct SquareAdd
@@ -51,13 +54,47 @@ struct SquareAdd
{
return y + (x * x);
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + (x_ * x_));
}
// For combining partial results
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
const T& partial2) const
{
return partial1 + partial2; // Just add the partial sums, don't square again
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const
{
float partial1_ = type_convert<float>(partial1);
float partial2_ = type_convert<float>(partial2);
return type_convert<T>(partial1_ + partial2_);
}
static constexpr bool requires_special_combine = true;
};
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -65,18 +102,24 @@ struct Max
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
}
static constexpr bool requires_special_combine = false;
};
struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -84,11 +127,15 @@ struct AbsMax
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));
}
static constexpr bool requires_special_combine = false;
};
} // namespace ReduceOp

View File

@@ -30,4 +30,82 @@ reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m,
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
// Generic reference reduce for arbitrary dimensions
template <
typename XDataType,
typename ComputeDataType,
typename YDataType,
typename ReduceOp,
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
// reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
HostTensor<YDataType>& y_tensor,
ReduceOp reduce_op,
KeptDim kept_dim,
ReduceDims reduce_dims)
{
const auto& x_lengths = x_tensor.mDesc.get_lengths();
// Calculate total kept elements (product of all kept dimension lengths)
index_t total_kept_elements = 1;
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Calculate total reduce elements (product of all reduce dimension lengths)
index_t total_reduce_elements = 1;
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
auto f = [&](auto linear_kept_idx) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Convert linear kept index to multi-dimensional kept indices
std::vector<index_t> kept_indices(kept_dim.size());
index_t temp_kept = linear_kept_idx;
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = kept_dim.size() - 1 - i;
constexpr auto dim = kept_dim.at(dim_idx);
const auto len = x_lengths[dim];
kept_indices[dim_idx] = temp_kept % len;
temp_kept /= len;
});
for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
{
// Convert linear reduce index to multi-dimensional reduce indices
std::vector<index_t> reduce_indices(reduce_dims.size());
index_t temp_reduce = reduce_idx;
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
constexpr auto dim = reduce_dims.at(dim_idx);
const auto len = x_lengths[dim];
reduce_indices[dim_idx] = temp_reduce % len;
temp_reduce /= len;
});
// Build full input tensor indices by combining kept and reduce indices
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Access input tensor element
const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
v_acc = reduce_op(v_acc, v_a);
}
// Calculate output tensor index using kept indices
// The output tensor has the same structure as the kept dimensions
std::vector<std::size_t> y_indices(kept_dim.size());
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
y_tensor(y_indices) = type_convert<YDataType>(v_acc);
};
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -5,8 +5,11 @@
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"

View File

@@ -7,20 +7,55 @@
namespace ck_tile {
// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
// dimension using a user-specified reduction function.
//
// The reduction is performed in a three-stage hierarchical approach:
//
// STAGE 1: Thread-level reduction (BlockReduce2d)
// ===============================================
// - Each thread processes multiple elements from the input tensor within its assigned data
// partition
// - Reduction is performed locally within each thread by iterating over assigned elements
// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
// dimension
// (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1)
// - Results are accumulated into a thread-local output tensor stored in registers
// - The output tensor distribution is derived from the input tensor's distribution using
// make_reduce_tile_distribution_encoding() to handle dimension reduction
//
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
// ================================================
// - Performs inter-thread reduction within each warp
// - Uses warp shuffle operations to exchange data between threads in the same warp
// - Implements a tree-reduction pattern with power-of-2 stages
// - Only reduces along dimensions that map to lane IDs within the warp
//
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
// ========================================================
// - Performs reduction across multiple warps within the same thread block
// - Uses shared memory (LDS) to facilitate data exchange between warps
// - Each warp's lane-0 thread stores its partial results to shared memory
// - All threads participate in loading and reducing data from shared memory
// - Implements block-level synchronization to ensure memory consistency
// BlockReduce2d: Thread-level reduction (Stage 1)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
// in-thread reduction
// Thread-level reduction implementation
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
CK_TILE_DEVICE constexpr BlockReduce2d() {}
template <typename XDistributedTensor_,
typename YDistributedTensor_,
typename ReduceFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
template <
typename XDistributedTensor_,
typename YDistributedTensor_,
typename ReduceFunc,
typename ReducePacksPerXDim =
uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
const ReduceFunc& reduce_func,
@@ -33,6 +68,7 @@ struct BlockReduce2d
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
},
ReducePacksPerXDim{});
#if 0
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
@@ -75,6 +111,8 @@ struct BlockReduce2d
return tensor;
}
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
template <typename XDistributedTensor_,
typename ReduceFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
@@ -91,6 +129,7 @@ struct BlockReduce2d
}
};
// BlockReduce2dSync: Warp-level reduction (Stage 2)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
@@ -145,8 +184,15 @@ struct BlockReduce2dSync
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
});
}
});
@@ -157,6 +203,7 @@ struct BlockReduce2dSync
}
};
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
@@ -263,8 +310,15 @@ struct BlockReduce2dCrossWarpSync
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
});
y_tensor.get_thread_buffer()(i_0) = v_local;

View File

@@ -0,0 +1,219 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
// Reduce2d Kernel:
// =======================================
// This kernel implements a 2D reduction operation that reduces data along the second dimension
// of a matrix. The reduction is performed in multiple hierarchical stages.
namespace ck_tile {
template <typename Problem_, typename Policy_ = Reduce2dDefaultPolicy>
struct Reduce
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
private:
// Helper function to calculate optimal vector size for input tensor
template <typename InputShape, typename ReduceDims>
static constexpr index_t CalculateInputVectorSize()
{
using S = typename Problem::BlockShape;
constexpr index_t memory_vector_size = 16 / sizeof(XDataType);
constexpr index_t thread_tile_vector_size = S::ThreadTile_N;
// Check if innermost reduce dimension is the last dimension (stride 1).
constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1);
// If innermost reduce dimension is not the last dim (not contiguous), limit vectorization
constexpr index_t stride_based_vector_size =
is_innermost_contiguous ? ck_tile::min(memory_vector_size, thread_tile_vector_size) : 1;
return stride_based_vector_size;
}
// Helper function to calculate optimal vector size for output tensor
static constexpr index_t CalculateOutputVectorSize()
{
using S = typename Problem::BlockShape;
constexpr index_t memory_vector_size = 16 / sizeof(YDataType);
constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);
return vector_size;
}
public:
template <typename InputShape, typename InputStrides, typename KeptDim, typename ReduceDims>
CK_TILE_DEVICE void operator()(const XDataType* p_x,
YDataType* p_y,
InputShape input_shape,
InputStrides input_strides,
KeptDim kept_dim,
ReduceDims reduce_dims) const
{
using S = typename Problem::BlockShape;
const auto iM = get_block_id() * S::Block_M;
static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
"Size of kept dimensions + reduced dimensions must equal input tensor rank");
// Extract lengths based on kept and reduced dimensions
const auto kept_lens = [&]() {
return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
number<kept_dim.size()>{});
}();
const auto reduce_lens = [&]() {
return generate_tuple(
[&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
number<reduce_dims.size()>{});
}();
const auto kept_merge_transform = make_merge_transform(kept_lens);
const auto reduce_merge_transform = make_merge_transform(reduce_lens);
auto reduce_func = typename Problem::ReduceOp{};
const XDataType custom_padding_value =
type_convert<XDataType>(reduce_func.template GetIdentityValue<ComputeDataType>());
// Calculate optimal vector size for input tensor
constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();
// Create input tensor view with custom padding value
auto desc = make_naive_tensor_descriptor(
input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});
// Create buffer view with custom padding value
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_x, desc.get_element_space_size(), custom_padding_value);
// Create tensor view with custom padding
const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
const auto transformed_x_tensor = pad_tensor_view(
transform_tensor_view(x_tensor,
make_tuple(kept_merge_transform, reduce_merge_transform),
make_tuple(kept_dim, reduce_dims),
make_tuple(sequence<0>{}, sequence<1>{})),
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
sequence<0, 1>{});
// Calculate strides for output tensor based on its own dimensions
const auto kept_strides = [&]() {
return generate_tuple(
[&](auto I) {
// Calculate stride for dimension I as product of all following dimensions
index_t stride = 1;
static_for<I + 1, kept_dim.size(), 1>{}(
[&](auto J) { stride *= kept_lens.at(number<J>{}); });
return stride;
},
number<kept_dim.size()>{});
}();
// Calculate optimal vector size for output tensor
constexpr auto y_tensor_vector_size = CalculateOutputVectorSize();
const auto y_m = make_naive_tensor_view<address_space_enum::global>(
p_y, kept_lens, kept_strides, number<y_tensor_vector_size>{}, number<1>{});
// Transform output tensor to 1D merged view
// This creates a view compatible with the 2D reduction pattern
const auto y_merged = transform_tensor_view(
y_m,
make_tuple(kept_merge_transform),
make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}),
make_tuple(sequence<0>{}));
auto x_window = make_tile_window(transformed_x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto y_window = make_tile_window(y_merged, make_tuple(number<S::Block_M>{}), {iM});
__shared__ char smem[Policy::template GetSmemSize<Problem>()];
// Get the merged dimension size from the transformed tensor
const auto merged_reduce_len =
transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{});
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(merged_reduce_len, S::Block_N));
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window));
auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_reduce2d(x, y_compute, reduce_func);
move_tile_window(x_window, {0, S::Block_N});
}
block_reduce2d_sync(y_compute, reduce_func);
block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);
store_tile(y_window, cast_tile<YDataType>(y_compute));
}
/// @brief Validates if the given arguments are supported by the 2D reduction kernel.
///
/// @param y_continous_dim Size of the continuous dimension of the output tensor.
/// Must be a multiple of ThreadTile_N for proper thread mapping.
///
/// @param input_strides The stride configuration of the input tensor.
/// The last stride must be 1 to ensure contiguous memory access
/// and enable efficient vectorized loads.
///
/// @return true if the arguments are supported, false otherwise.
/// Error messages are logged when CK_TILE_LOGGING is enabled.
///
/// @note Requirements:
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
/// - input_strides[-1] == 1 (for contiguous memory access)
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
{
using S = typename Problem::BlockShape;
if(y_continous_dim % S::ThreadTile_N != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
}
return false;
}
if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Input tensor's last stride must be 1 to support correct vector access!");
}
return false;
}
return true;
}
};
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,7 +9,7 @@
namespace ck_tile {
struct BlockReduce2dDefaultPolicy
struct Reduce2dDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
@@ -18,8 +18,9 @@ struct BlockReduce2dDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename BlockShape_,
typename ReduceOp_>
struct Reduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using ReduceOp = ReduceOp_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
} // namespace ck_tile

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockWarps, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WarpTile, // warp size, seq<M, N>
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct Reduce2dShape
{
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
} // namespace ck_tile