From b6f0e98da6a96ebc1ff6775b4cd88d04746c1a2f Mon Sep 17 00:00:00 2001
From: Yashvardhan Agarwal
Date: Wed, 6 Aug 2025 16:36:59 +0300
Subject: [PATCH] General 2D Reduction Kernel (#2535)

* General 2D Reduction Kernel
* Move the reduction kernel from the example
* Split the code and add the necessary policy, problem, shape files as per ck_tile convention
* Add/modify the headers
* Modified the example to work with the 'new' kernel
* Added tests for the kernel
* N-D reference reduce
* Added support for N-D input with transform to 2D
* Added padding to support various input sized tensors
* Bug fix in the thread buffer constructor
* Some comments to explain the reduce2d block kernel
* comments resolution
* clang-format
* comments resolution
* clang-format
* clang-format
* comments resolution
* clang-format

[ROCm/composable_kernel commit: 4750b293fe0abfa44a32181742a48b1dfec468f7]
---
 example/ck_tile/05_reduce/reduce.cpp          |  63 ++-
 example/ck_tile/05_reduce/reduce.hpp          | 164 --------
 .../ck_tile/core/container/thread_buffer.hpp  |   6 +-
 .../ck_tile/core/utility/reduce_operator.hpp  |  57 ++-
 .../host/reference/reference_reduce.hpp       |  78 ++++
 include/ck_tile/ops/reduce.hpp                |   5 +-
 .../ops/reduce/block/block_reduce2d.hpp       |  72 +++-
 .../ops/reduce/kernel/reduce2d_kernel.hpp     | 219 +++++++++++
 .../reduce2d_default_policy.hpp}              |   9 +-
 .../ops/reduce/pipeline/reduce2d_problem.hpp  |  27 ++
 .../ops/reduce/pipeline/reduce2d_shape.hpp    |  37 ++
 test/ck_tile/CMakeLists.txt                   |   1 +
 test/ck_tile/reduce/CMakeLists.txt            |   7 +
 test/ck_tile/reduce/test_reduce2d.cpp         | 359 ++++++++++++++++++
 14 files changed, 905 insertions(+), 199 deletions(-)
 delete mode 100644 example/ck_tile/05_reduce/reduce.hpp
 create mode 100644 include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp
 rename include/ck_tile/ops/reduce/{block/block_reduce2d_default_policy.hpp => pipeline/reduce2d_default_policy.hpp} (89%)
 create mode 100644 include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp
 create mode 100644 include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
 create mode 100644 test/ck_tile/reduce/CMakeLists.txt
 create mode 100644 test/ck_tile/reduce/test_reduce2d.cpp

diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp
index 602661f779..cf816caa88 100644
--- a/example/ck_tile/05_reduce/reduce.cpp
+++ b/example/ck_tile/05_reduce/reduce.cpp
@@ -1,16 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+ #include "ck_tile/host.hpp" -#include "reduce.hpp" +#include "ck_tile/ops/reduce.hpp" #include auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3328", "m dimension") - .insert("n", "4096", "n dimension") + arg_parser.insert("n", "32", "n dimension") + .insert("h", "7", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") - .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -23,15 +28,28 @@ bool run(const ck_tile::ArgParser& arg_parser) using ComputeDataType = float; using YDataType = DataType; - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - ck_tile::HostTensor x_host({m, n}); - ck_tile::HostTensor y_host_ref({m}); - ck_tile::HostTensor y_host_dev({m}); + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_dev({N, C}, {C, 1}); ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); @@ -54,7 +72,9 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr ck_tile::index_t kBlockSize = 256; constexpr ck_tile::index_t kBlockPerCu = 1; - ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + ck_tile::index_t kept_dim_len_prod = N * C; + ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); std::cout << "grid size " << kGridSize << std::endl; using Shape = ck_tile::Reduce2dShape; @@ -63,6 +83,17 @@ bool run(const ck_tile::ArgParser& arg_parser) using Kernel = ck_tile::Reduce; + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! 
Arguments not supported!\n"); + } + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, ck_tile::make_kernel( Kernel{}, @@ -71,10 +102,12 @@ bool run(const ck_tile::ArgParser& arg_parser) 0, static_cast(x_buf.GetDeviceBuffer()), static_cast(y_buf.GetDeviceBuffer()), - m, - n)); + input_shape, + input_strides, + kept_dim, + reduce_dims)); - std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -86,7 +119,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { // reference ck_tile::reference_reduce( - x_host, y_host_ref, ReduceOp{}); + x_host, y_host_ref, ReduceOp{}, kept_dim, reduce_dims); y_buf.FromDevice(y_host_dev.mData.data()); pass = ck_tile::check_err(y_host_dev, y_host_ref); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp deleted file mode 100644 index 6fbb0b4274..0000000000 --- a/example/ck_tile/05_reduce/reduce.hpp +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" -#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp" - -namespace ck_tile { - -template - typename BlockTile, // block size, seq - typename WarpTile, // warp size, seq - typename Vector> // contiguous pixels(vector size) along seq -struct Reduce2dShape -{ - static constexpr index_t Block_M = BlockTile::at(number<0>{}); - static constexpr index_t Block_N = BlockTile::at(number<1>{}); - - static constexpr index_t Warp_M = WarpTile::at(number<0>{}); - static constexpr index_t Warp_N = WarpTile::at(number<1>{}); - - static constexpr index_t Vector_M = Vector::at(number<0>{}); - static constexpr index_t Vector_N = Vector::at(number<1>{}); - - static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); - static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); - - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); - - static constexpr index_t BlockSize = - ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); -}; - -template -struct Reduce2dProblem -{ - using XDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using YDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; - using ReduceOp = ReduceOp_; - - static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; - static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; -}; - -template -struct Reduce -{ - using Problem = ck_tile::remove_cvref_t; - using Policy = ck_tile::remove_cvref_t; - - using XDataType = ck_tile::remove_cvref_t; - using ComputeDataType = ck_tile::remove_cvref_t; - using YDataType = ck_tile::remove_cvref_t; - -#if 0 - CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) - const - { - using S = typename Problem::BlockShape; - - const auto x_m_n = make_naive_tensor_view( - p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); - - const auto y_m = make_naive_tensor_view_packed( - p_y, 
make_tuple(M), number<1>{}); - - const auto iM = get_block_id() * S::Block_M; - - auto x_window = make_tile_window(x_m_n, - make_tuple(number{}, number{}), - {iM, 0}, - Policy::template MakeXBlockTileDistribution()); - - auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); - - const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; }; - - const XDataType reduce_init_value = 0; - - constexpr auto reduce_dims = sequence<1>{}; - - auto y_compute = decltype(block_tile_reduce( - load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){}; - - set_tile(y_compute, reduce_init_value); - - index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); - - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) - { - const auto x = load_tile(x_window); - block_tile_reduce(y_compute, x, reduce_dims, f_reduce); - move_tile_window(x_window, {0, S::Block_N}); - } - - block_tile_reduce_sync(y_compute, f_reduce); - - store_tile(y_window, cast_tile(y_compute)); - } -#else - CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const - { - using S = typename Problem::BlockShape; - - const auto x_m_n = make_naive_tensor_view( - p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); - - const auto y_m = make_naive_tensor_view_packed( - p_y, make_tuple(M), number<1>{}); - - const auto iM = get_block_id() * S::Block_M; - - auto x_window = make_tile_window(x_m_n, - make_tuple(number{}, number{}), - {iM, 0}, - Policy::template MakeXBlockTileDistribution()); - - auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); - - __shared__ char smem[Policy::template GetSmemSize()]; - - index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); - - auto reduce_func = typename Problem::ReduceOp{}; - auto block_reduce2d = Policy::template GetBlockReduce2d(); - auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); - auto block_reduce2d_cross_warp_sync = - Policy::template GetBlockReduce2dCrossWarpSync(); - - using XTensorType = decltype(load_tile(x_window)); - auto y_compute = block_reduce2d.template MakeYBlockTile(); - set_tile(y_compute, reduce_func.template GetIdentityValue()); - - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) - { - const auto x = load_tile(x_window); - block_reduce2d(x, y_compute, reduce_func); - move_tile_window(x_window, {0, S::Block_N}); - } - - block_reduce2d_sync(y_compute, reduce_func); - block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func); - - store_tile(y_window, cast_tile(y_compute)); - } -#endif -}; - -} // namespace ck_tile diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index 77c46e1b8c..d67581e7d2 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -42,7 +42,11 @@ struct thread_buffer { // TODO: this ctor can't ignore CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {} - CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {} + CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} { + static_for<0, N, 1>{}( + [&](auto i) { data[i] = o; } + ); + } CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } CK_TILE_HOST_DEVICE auto & get() {return data; } diff --git a/include/ck_tile/core/utility/reduce_operator.hpp 
b/include/ck_tile/core/utility/reduce_operator.hpp index 8b15d187fe..2d7ac78b06 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -26,7 +26,8 @@ struct Add } template || std::is_same_v>> + typename = std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v>> CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const { float y_ = type_convert(y); @@ -34,6 +35,8 @@ struct Add return type_convert(y_ + x_); } + + static constexpr bool requires_special_combine = false; }; struct SquareAdd @@ -51,13 +54,47 @@ struct SquareAdd { return y + (x * x); } + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const + { + float y_ = type_convert(y); + float x_ = type_convert(x); + return type_convert(y_ + (x_ * x_)); + } + + // For combining partial results + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1, + const T& partial2) const + { + return partial1 + partial2; // Just add the partial sums, don't square again + } + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const + { + float partial1_ = type_convert(partial1); + float partial2_ = type_convert(partial2); + return type_convert(partial1_ + partial2_); + } + + static constexpr bool requires_special_combine = true; }; struct Max { template || std::is_same_v || - std::is_same_v || std::is_same_v>> + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v>> CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() { return numeric::min(); @@ -65,18 +102,24 @@ struct Max template || std::is_same_v || - std::is_same_v || std::is_same_v>> + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return max(y, x); } + + static constexpr bool requires_special_combine = false; }; struct AbsMax { template || std::is_same_v || - std::is_same_v || std::is_same_v>> + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v>> CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() { return numeric::min(); @@ -84,11 +127,15 @@ struct AbsMax template || std::is_same_v || - std::is_same_v || std::is_same_v>> + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return max(y, abs(x)); } + + static constexpr bool requires_special_combine = false; }; } // namespace ReduceOp diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp index 8f8aa23670..9952b7b009 100644 --- a/include/ck_tile/host/reference/reference_reduce.hpp +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -30,4 +30,82 @@ reference_reduce(const HostTensor& x_m_n, HostTensor& y_m, make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); } + +// Generic reference reduce for arbitrary dimensions +template < + typename XDataType, + typename ComputeDataType, + typename YDataType, + typename ReduceOp, + typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension 
indices to keep
+    typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
+                         // reduce
+CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
+                                   HostTensor<YDataType>& y_tensor,
+                                   ReduceOp reduce_op,
+                                   KeptDim kept_dim,
+                                   ReduceDims reduce_dims)
+{
+    const auto& x_lengths = x_tensor.mDesc.get_lengths();
+
+    // Calculate total kept elements (product of all kept dimension lengths)
+    index_t total_kept_elements = 1;
+    static_for<0, kept_dim.size(), 1>{}(
+        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
+
+    // Calculate total reduce elements (product of all reduce dimension lengths)
+    index_t total_reduce_elements = 1;
+    static_for<0, reduce_dims.size(), 1>{}(
+        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
+
+    auto f = [&](auto linear_kept_idx) {
+        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
+
+        // Convert linear kept index to multi-dimensional kept indices
+        std::vector<std::size_t> kept_indices(kept_dim.size());
+        index_t temp_kept = linear_kept_idx;
+        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
+            constexpr auto dim_idx = kept_dim.size() - 1 - i;
+            constexpr auto dim     = kept_dim.at(dim_idx);
+            const auto len         = x_lengths[dim];
+            kept_indices[dim_idx]  = temp_kept % len;
+            temp_kept /= len;
+        });
+
+        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
+        {
+            // Convert linear reduce index to multi-dimensional reduce indices
+            std::vector<std::size_t> reduce_indices(reduce_dims.size());
+            index_t temp_reduce = reduce_idx;
+            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
+                constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
+                constexpr auto dim      = reduce_dims.at(dim_idx);
+                const auto len          = x_lengths[dim];
+                reduce_indices[dim_idx] = temp_reduce % len;
+                temp_reduce /= len;
+            });
+
+            // Build full input tensor indices by combining kept and reduce indices
+            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
+            static_for<0, kept_dim.size(), 1>{}(
+                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
+            static_for<0, reduce_dims.size(), 1>{}(
+                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
+
+            // Access input tensor element
+            const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
+
+            v_acc = reduce_op(v_acc, v_a);
+        }
+
+        // Calculate output tensor index using kept indices
+        // The output tensor has the same structure as the kept dimensions
+        std::vector<std::size_t> y_indices(kept_dim.size());
+        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
+
+        y_tensor(y_indices) = type_convert<YDataType>(v_acc);
+    };
+
+    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
+}
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp
index 80ead84e85..042e0b98c2 100644
--- a/include/ck_tile/ops/reduce.hpp
+++ b/include/ck_tile/ops/reduce.hpp
@@ -5,8 +5,11 @@
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
+#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
+#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
+#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
+#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
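To make the call pattern of the generic N-D reference_reduce above concrete, here is a minimal host-side sketch (a hypothetical standalone snippet; the fp32 types and the 2x3x4 shape are chosen only for illustration):

    #include "ck_tile/host.hpp"
    #include "ck_tile/ops/reduce.hpp"

    int main()
    {
        // Keep dim 0, reduce dims {1, 2}: a 2x3x4 input collapses to 2 sums.
        ck_tile::HostTensor<float> x({2, 3, 4});
        ck_tile::HostTensor<float> y({2});
        ck_tile::FillUniformDistribution<float>{-5.f, 5.f}(x);

        ck_tile::reference_reduce<float /*X*/, float /*Compute*/, float /*Y*/>(
            x, y, ck_tile::ReduceOp::Add{}, ck_tile::sequence<0>{}, ck_tile::sequence<1, 2>{});

        // y({i}) now holds the sum of x(i, j, k) over all j, k.
        return 0;
    }
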
diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 62c9944bd2..849fa6c252 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -7,20 +7,55 @@ namespace ck_tile { +// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second +// dimension using a user-specified reduction function. +// +// The reduction is performed in a three-stage hierarchical approach: +// +// STAGE 1: Thread-level reduction (BlockReduce2d) +// =============================================== +// - Each thread processes multiple elements from the input tensor within its assigned data +// partition +// - Reduction is performed locally within each thread by iterating over assigned elements +// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per +// dimension +// (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1) +// - Results are accumulated into a thread-local output tensor stored in registers +// - The output tensor distribution is derived from the input tensor's distribution using +// make_reduce_tile_distribution_encoding() to handle dimension reduction +// +// STAGE 2: Warp-level reduction (BlockReduce2dSync) +// ================================================ +// - Performs inter-thread reduction within each warp +// - Uses warp shuffle operations to exchange data between threads in the same warp +// - Implements a tree-reduction pattern with power-of-2 stages +// - Only reduces along dimensions that map to lane IDs within the warp +// +// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync) +// ======================================================== +// - Performs reduction across multiple warps within the same thread block +// - Uses shared memory (LDS) to facilitate data exchange between warps +// - Each warp's lane-0 thread stores its partial results to shared memory +// - All threads participate in loading and reducing data from shared memory +// - Implements block-level synchronization to ensure memory consistency + +// BlockReduce2d: Thread-level reduction (Stage 1) template struct BlockReduce2d { - // in-thread reduction + // Thread-level reduction implementation using Problem = remove_cvref_t; using XDataType = typename Problem::XDataType; using ComputeDataType = typename Problem::ComputeDataType; CK_TILE_DEVICE constexpr BlockReduce2d() {} - template > + template < + typename XDistributedTensor_, + typename YDistributedTensor_, + typename ReduceFunc, + typename ReducePacksPerXDim = + uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func, @@ -33,6 +68,7 @@ struct BlockReduce2d y_tensor(idx_0), ck_tile::type_convert(x_tensor[idx_])...); }, ReducePacksPerXDim{}); + #if 0 constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; @@ -75,6 +111,8 @@ struct BlockReduce2d return tensor; } + // uniform_sequence_gen_t generates sequence of NSize elements filled with Value + // e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4} template > @@ -91,6 +129,7 @@ struct BlockReduce2d } }; +// BlockReduce2dSync: Warp-level reduction (Stage 2) template struct BlockReduce2dSync { @@ -145,8 +184,15 @@ struct BlockReduce2dSync // pull data from remote lane 
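                    // (tree reduction with power-of-2 stages: stage k pairs lanes 2^k
                    // apart, so log2(#participating lanes) shuffle rounds leave every
                    // lane holding the row's combined partial result)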
const auto v_remote = warp_shuffle(v_local, src_lane); - // reduce - v_local = reduce_func(v_local, v_remote); + // For reduce, use combine_partial_results for operations that require it + if constexpr(ReduceFunc::requires_special_combine) + { + v_local = reduce_func.combine_partial_results(v_local, v_remote); + } + else + { + v_local = reduce_func(v_local, v_remote); + } }); } }); @@ -157,6 +203,7 @@ struct BlockReduce2dSync } }; +// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3) template struct BlockReduce2dCrossWarpSync { @@ -263,8 +310,15 @@ struct BlockReduce2dCrossWarpSync constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - // reduce - v_local = reduce_func(v_local, v_remote); + // For reduce, use combine_partial_results for operations that require it + if constexpr(ReduceFunc::requires_special_combine) + { + v_local = reduce_func.combine_partial_results(v_local, v_remote); + } + else + { + v_local = reduce_func(v_local, v_remote); + } }); y_tensor.get_thread_buffer()(i_0) = v_local; diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp new file mode 100644 index 0000000000..f65487ea6e --- /dev/null +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" + +// Reduce2d Kernel: +// ======================================= +// This kernel implements a 2D reduction operation that reduces data along the second dimension +// of a matrix. The reduction is performed in multiple hierarchical stages. + +namespace ck_tile { + +template +struct Reduce +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + + private: + // Helper function to calculate optimal vector size for input tensor + template + static constexpr index_t CalculateInputVectorSize() + { + using S = typename Problem::BlockShape; + constexpr index_t memory_vector_size = 16 / sizeof(XDataType); + constexpr index_t thread_tile_vector_size = S::ThreadTile_N; + + // Check if innermost reduce dimension is the last dimension (stride 1). + constexpr auto innermost_reduce_dim = ReduceDims{}.at(number{}); + constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1); + + // If innermost reduce dimension is not the last dim (not contiguous), limit vectorization + constexpr index_t stride_based_vector_size = + is_innermost_contiguous ? 
ck_tile::min(memory_vector_size, thread_tile_vector_size) : 1; + + return stride_based_vector_size; + } + + // Helper function to calculate optimal vector size for output tensor + static constexpr index_t CalculateOutputVectorSize() + { + using S = typename Problem::BlockShape; + constexpr index_t memory_vector_size = 16 / sizeof(YDataType); + constexpr index_t thread_tile_vector_size = S::ThreadTile_M; + constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size); + + return vector_size; + } + + public: + template + CK_TILE_DEVICE void operator()(const XDataType* p_x, + YDataType* p_y, + InputShape input_shape, + InputStrides input_strides, + KeptDim kept_dim, + ReduceDims reduce_dims) const + { + using S = typename Problem::BlockShape; + const auto iM = get_block_id() * S::Block_M; + + static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(), + "Size of kept dimensions + reduced dimensions must equal input tensor rank"); + + // Extract lengths based on kept and reduced dimensions + const auto kept_lens = [&]() { + return generate_tuple([&](auto I) { return input_shape.at(number{}); }, + number{}); + }(); + const auto reduce_lens = [&]() { + return generate_tuple( + [&](auto I) { return input_shape.at(number{}); }, + number{}); + }(); + + const auto kept_merge_transform = make_merge_transform(kept_lens); + const auto reduce_merge_transform = make_merge_transform(reduce_lens); + + auto reduce_func = typename Problem::ReduceOp{}; + const XDataType custom_padding_value = + type_convert(reduce_func.template GetIdentityValue()); + + // Calculate optimal vector size for input tensor + constexpr auto x_tensor_vector_size = CalculateInputVectorSize(); + + // Create input tensor view with custom padding value + auto desc = make_naive_tensor_descriptor( + input_shape, input_strides, number{}, number<1>{}); + + // Create buffer view with custom padding value + auto buffer_view = make_buffer_view( + p_x, desc.get_element_space_size(), custom_padding_value); + + // Create tensor view with custom padding + const auto x_tensor = tensor_view{buffer_view, desc}; + const auto transformed_x_tensor = pad_tensor_view( + transform_tensor_view(x_tensor, + make_tuple(kept_merge_transform, reduce_merge_transform), + make_tuple(kept_dim, reduce_dims), + make_tuple(sequence<0>{}, sequence<1>{})), + make_tuple(number{}, number{}), + sequence<0, 1>{}); + + // Calculate strides for output tensor based on its own dimensions + const auto kept_strides = [&]() { + return generate_tuple( + [&](auto I) { + // Calculate stride for dimension I as product of all following dimensions + index_t stride = 1; + static_for{}( + [&](auto J) { stride *= kept_lens.at(number{}); }); + return stride; + }, + number{}); + }(); + + // Calculate optimal vector size for output tensor + constexpr auto y_tensor_vector_size = CalculateOutputVectorSize(); + + const auto y_m = make_naive_tensor_view( + p_y, kept_lens, kept_strides, number{}, number<1>{}); + + // Transform output tensor to 1D merged view + // This creates a view compatible with the 2D reduction pattern + const auto y_merged = transform_tensor_view( + y_m, + make_tuple(kept_merge_transform), + make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}), + make_tuple(sequence<0>{})); + + auto x_window = make_tile_window(transformed_x_tensor, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_merged, make_tuple(number{}), {iM}); + + 
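+        // The reduction now follows the three stages described in
+        // block_reduce2d.hpp: per-thread accumulation over the N-tile loop
+        // below, then intra-warp combination via shuffles, then cross-warp
+        // combination through the LDS scratch buffer declared next.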
__shared__ char smem[Policy::template GetSmemSize()]; + + // Get the merged dimension size from the transformed tensor + const auto merged_reduce_len = + transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{}); + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(merged_reduce_len, S::Block_N)); + + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto y_compute = block_reduce2d.template MakeYBlockTile(); + set_tile(y_compute, reduce_func.template GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_reduce2d(x, y_compute, reduce_func); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_compute, reduce_func); + block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func); + + store_tile(y_window, cast_tile(y_compute)); + } + + /// @brief Validates if the given arguments are supported by the 2D reduction kernel. + /// + /// @param y_continous_dim Size of the continuous dimension of the output tensor. + /// Must be a multiple of ThreadTile_N for proper thread mapping. + /// + /// @param input_strides The stride configuration of the input tensor. + /// The last stride must be 1 to ensure contiguous memory access + /// and enable efficient vectorized loads. + /// + /// @return true if the arguments are supported, false otherwise. + /// Error messages are logged when CK_TILE_LOGGING is enabled. + /// + /// @note Requirements: + /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution) + /// - input_strides[-1] == 1 (for contiguous memory access) + CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides) + { + using S = typename Problem::BlockShape; + + if(y_continous_dim % S::ThreadTile_N != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!"); + } + return false; + } + + if(input_strides.at(number{}) != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Input tensor's last stride must be 1 to support correct vector access!"); + } + return false; + } + + return true; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp similarity index 89% rename from include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp rename to include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp index 3c547242d5..27bb4bcdcb 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
 #pragma once

@@ -9,7 +9,7 @@
 namespace ck_tile {

-struct BlockReduce2dDefaultPolicy
+struct Reduce2dDefaultPolicy
 {
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
     {
@@ -18,8 +18,9 @@
         return make_static_tile_distribution(
             tile_distribution_encoding<
                 sequence<>,
-                tuple,
-                 sequence>,
+                tuple<
+                    sequence,
+                    sequence>,
                 tuple, sequence<1, 2>>,
                 tuple, sequence<2, 2>>,
                 sequence<1, 1, 2, 2>,
diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp
new file mode 100644
index 0000000000..67fdec9286
--- /dev/null
+++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp
@@ -0,0 +1,27 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <typename XDataType_,
+          typename ComputeDataType_,
+          typename YDataType_,
+          typename BlockShape_,
+          typename ReduceOp_>
+struct Reduce2dProblem
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using YDataType       = remove_cvref_t<YDataType_>;
+    using BlockShape      = remove_cvref_t<BlockShape_>;
+    using ReduceOp        = ReduceOp_;
+
+    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
+    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
new file mode 100644
index 0000000000..31eb1f2f4f
--- /dev/null
+++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+template <typename BlockWarps, // num warps along seq<M, N>
+          typename BlockTile,  // block size, seq<M, N>
+          typename WarpTile,   // warp size, seq<M, N>
+          typename ThreadTile> // contiguous pixels (vector size) along seq<M, N>
+struct Reduce2dShape
+{
+    static constexpr index_t Block_M = BlockTile::at(number<0>{});
+    static constexpr index_t Block_N = BlockTile::at(number<1>{});
+
+    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
+    static constexpr index_t Warp_N = WarpTile::at(number<1>{});
+
+    static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
+    static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
+
+    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
+    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
+
+    static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
+    static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
+
+    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
+    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
+
+    static constexpr index_t BlockSize =
+        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
+};
+} // namespace ck_tile
diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt
index 42605f2513..9a1df56208 100644
--- a/test/ck_tile/CMakeLists.txt
+++ b/test/ck_tile/CMakeLists.txt
@@ -21,3 +21,4 @@ add_subdirectory(add_rmsnorm2d_rdquant)
 # add_subdirectory(layernorm2d)
 # add_subdirectory(rmsnorm2d)
 add_subdirectory(gemm_block_scale)
+add_subdirectory(reduce)
\ No newline at end of file
diff --git a/test/ck_tile/reduce/CMakeLists.txt b/test/ck_tile/reduce/CMakeLists.txt
new file mode 100644
index 0000000000..052669e20a
--- /dev/null
+++ b/test/ck_tile/reduce/CMakeLists.txt
@@ -0,0 +1,7 @@
+if(GPU_TARGETS MATCHES 
"gfx9") + add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp) + if(result EQUAL 0) + target_link_libraries(test_ck_tile_reduce2d PRIVATE utility) + endif() +endif() + diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp new file mode 100644 index 0000000000..4ce0b56ef3 --- /dev/null +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +template +class TestCkTileReduce : public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using ComputeDataType = std::tuple_element_t<1, Tuple>; + using YDataType = std::tuple_element_t<2, Tuple>; + using ReduceOpType = std::tuple_element_t<3, Tuple>; + using BlockWarps_ = std::tuple_element_t<4, Tuple>; + using BlockTile_ = std::tuple_element_t<5, Tuple>; + using WarpTile_ = std::tuple_element_t<6, Tuple>; + using ThreadTile_ = std::tuple_element_t<7, Tuple>; + + using TestReduce2dShape = + ck_tile::Reduce2dShape; + + template + void RunGenericTest(const std::vector& input_shape, + const std::vector& input_strides, + const std::vector& output_shape, + const std::vector& output_strides, + ck_tile::index_t kept_dim_len_prod, + ck_tile::index_t total_reduce_elements, + KeptDimSeq kept_dims, + ReduceDimSeq reduce_dims) + { + ck_tile::HostTensor h_x(input_shape, input_strides); + ck_tile::HostTensor h_y(output_shape, output_strides); + ck_tile::HostTensor h_y_ref(output_shape, output_strides); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + h_y.SetZero(); + h_y_ref.SetZero(); + + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_y_mem(h_y.get_element_space_size_in_bytes()); + + d_x_mem.ToDevice(h_x.data()); + d_y_mem.ToDevice(h_y.data()); // Initialize device output buffer + + // Problem and kernel setup + using Problem = ck_tile:: + Reduce2dProblem; + + using Kernel = ck_tile::Reduce; + + // Launch configuration + constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = + (kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M; + + // Generic helper to create tuple from vector based on compile-time size + auto make_shape_tuple = [](const std::vector& vec) { + return [&vec](std::index_sequence) { + return ck_tile::make_tuple(vec[I]...); + }(std::make_index_sequence{}); + }; + + auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); + auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); + + if(!Kernel::IsSupportedArgument( + output_shape[output_shape.size() - 1], + input_strides_tuple)) // output tensor's continuous dimension + { + throw std::runtime_error("Wrong! 
Arguments not supported!\n"); + } + + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_y_mem.GetDeviceBuffer()), + input_shape_tuple, + input_strides_tuple, + kept_dims, + reduce_dims)); + + // Get results back + d_y_mem.FromDevice(h_y.data()); + + // Reference computation + ck_tile::reference_reduce( + h_x, h_y_ref, ReduceOpType{}, kept_dims, reduce_dims); + + // Calculate proper error thresholds based on data types and number of accumulations + const auto rtol = ck_tile::get_relative_threshold( + total_reduce_elements); + const auto atol = ck_tile::get_absolute_threshold( + 5.0f, total_reduce_elements); + + bool result = + ck_tile::check_err(h_y, h_y_ref, "Error: Incorrect reduce results!", rtol, atol); + EXPECT_TRUE(result); + } + + // Convenience functions for specific dimensional patterns + void RunTest2D_KeepDim0_ReduceDim1(ck_tile::index_t dim0, ck_tile::index_t dim1) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1}; + std::vector input_strides = {dim1, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; + ck_tile::index_t total_reduce_elements = dim1; + + RunGenericTest<2>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim0_ReduceDim12(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim1 * dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim01_ReduceDim2(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = ck_tile::sequence<2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0, dim1}; + std::vector output_strides = {dim1, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0 * dim1; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest4D_KeepDim01_ReduceDim23(ck_tile::index_t N, + ck_tile::index_t C, + ck_tile::index_t H, + ck_tile::index_t W) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = 
ck_tile::sequence<2, 3>{};
+
+        // Input shape and strides
+        std::vector<ck_tile::index_t> input_shape   = {N, C, H, W};
+        std::vector<ck_tile::index_t> input_strides = {C * H * W, H * W, W, 1};
+
+        // Output shape and strides (keep dim0, dim1)
+        std::vector<ck_tile::index_t> output_shape   = {N, C};
+        std::vector<ck_tile::index_t> output_strides = {C, 1};
+
+        // Calculate products
+        ck_tile::index_t kept_dim_len_prod     = N * C; // product of kept dimensions
+        ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions
+
+        RunGenericTest<4>(input_shape,
+                          input_strides,
+                          output_shape,
+                          output_strides,
+                          kept_dim_len_prod,
+                          total_reduce_elements,
+                          kept_dims,
+                          reduce_dims);
+    }
+
+    void RunTest4D_KeepDim03_ReduceDim12(ck_tile::index_t N,
+                                         ck_tile::index_t H,
+                                         ck_tile::index_t W,
+                                         ck_tile::index_t C)
+    {
+        constexpr auto kept_dims   = ck_tile::sequence<0, 3>{};
+        constexpr auto reduce_dims = ck_tile::sequence<1, 2>{};
+
+        // Input shape and strides
+        std::vector<ck_tile::index_t> input_shape   = {N, H, W, C};
+        std::vector<ck_tile::index_t> input_strides = {H * W * C, W * C, C, 1};
+
+        // Output shape and strides (keep dim0, dim3)
+        std::vector<ck_tile::index_t> output_shape   = {N, C};
+        std::vector<ck_tile::index_t> output_strides = {C, 1};
+
+        // Calculate products
+        ck_tile::index_t kept_dim_len_prod     = N * C; // product of kept dimensions
+        ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions
+
+        RunGenericTest<4>(input_shape,
+                          input_strides,
+                          output_shape,
+                          output_strides,
+                          kept_dim_len_prod,
+                          total_reduce_elements,
+                          kept_dims,
+                          reduce_dims);
+    }
+};
+
+// Shape parameters for different test configurations
+using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
+using Shape1_BlockTile  = ck_tile::sequence<128, 128>;
+using Shape1_WarpTile   = ck_tile::sequence<32, 128>;
+using Shape1_ThreadTile = ck_tile::sequence<8, 8>;
+
+using Shape2_BlockWarps = ck_tile::sequence<2, 2>; // Cross-warp reduction test
+using Shape2_BlockTile  = ck_tile::sequence<2, 1024>;
+using Shape2_WarpTile   = ck_tile::sequence<1, 512>;
+using Shape2_ThreadTile = ck_tile::sequence<1, 8>;
+
+// Test configurations for different data types and operations
+using TestConfig_F32_Add = std::tuple<float, float, float, ck_tile::ReduceOp::Add,
+                                      Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile, Shape1_ThreadTile>;
+
+using TestConfig_F16_Add = std::tuple<ck_tile::half_t, float, ck_tile::half_t, ck_tile::ReduceOp::Add,
+                                      Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile, Shape1_ThreadTile>;
+
+using TestConfig_F32_CrossWarp = std::tuple<float, float, float, ck_tile::ReduceOp::Add,
+                                            Shape2_BlockWarps, Shape2_BlockTile, Shape2_WarpTile, Shape2_ThreadTile>;
+
+using TestConfig_F32_Max = std::tuple<float, float, float, ck_tile::ReduceOp::Max,
+                                      Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile, Shape1_ThreadTile>;
+
+using TestConfig_F32_SquareAdd = std::tuple<float, float, float, ck_tile::ReduceOp::SquareAdd,
+                                            Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile, Shape1_ThreadTile>;
+
+using TestTypes = ::testing::Types<TestConfig_F32_Add,
+                                   TestConfig_F16_Add,
+                                   TestConfig_F32_CrossWarp,
+                                   TestConfig_F32_Max,
+                                   TestConfig_F32_SquareAdd>;
+
+TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);
+
+// 2D Tests - Keep dim0, reduce dim1
+TYPED_TEST(TestCkTileReduce, Test2D_KeepDim0_ReduceDim1_64x32)
+{
+    this->RunTest2D_KeepDim0_ReduceDim1(64, 32);
+}
+
+TYPED_TEST(TestCkTileReduce, Test2D_KeepDim0_ReduceDim1_1024x512)
+{
+    this->RunTest2D_KeepDim0_ReduceDim1(1024, 512);
+}
+
+// 3D Tests - Keep dim0, reduce dim1,2
+TYPED_TEST(TestCkTileReduce, Test3D_KeepDim0_ReduceDim12_128x128x8)
+{
+    this->RunTest3D_KeepDim0_ReduceDim12(128, 128, 8);
+}
+
+// 3D Tests - Keep dim0,1, reduce dim2
+TYPED_TEST(TestCkTileReduce, Test3D_KeepDim01_ReduceDim2_512x1024x16)
+{
+    this->RunTest3D_KeepDim01_ReduceDim2(512, 1024, 16);
+}
+
+// 4D Tests - Keep dim0,1, reduce dim2,3 (NCHW -> NC)
+TYPED_TEST(TestCkTileReduce, Test4D_KeepDim01_ReduceDim23_32x256x16x16)
+{
+    this->RunTest4D_KeepDim01_ReduceDim23(32, 256, 16, 16);
+}
+
+// 4D Tests - Keep dim0,3, reduce dim1,2 (NHWC -> NC)
+TYPED_TEST(TestCkTileReduce, Test4D_KeepDim03_ReduceDim12_16x32x32x128)
+{
+    this->RunTest4D_KeepDim03_ReduceDim12(16, 32, 32, 128);
+}
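
Taken together, the new Shape/Problem/Policy/kernel pieces compose as in the following minimal host-side sketch for a plain M x N row reduction (tile sizes follow the Shape1 configuration above; that ck_tile::Reduce is parameterized as Reduce<Problem, Policy> is an assumption carried over from the example and test code, and the actual launch proceeds exactly as in example/ck_tile/05_reduce/reduce.cpp):

    // Tile configuration (Shape1 from the tests above).
    using BlockWarps = ck_tile::sequence<4, 1>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using WarpTile   = ck_tile::sequence<32, 128>;
    using ThreadTile = ck_tile::sequence<8, 8>;

    using Shape   = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    using Problem = ck_tile::Reduce2dProblem<float, // XDataType
                                             float, // ComputeDataType
                                             float, // YDataType
                                             Shape,
                                             ck_tile::ReduceOp::Add>;
    // Assumed parameterization: the problem plus the default policy, as in the test file.
    using Kernel = ck_tile::Reduce<Problem, ck_tile::Reduce2dDefaultPolicy>;

    const ck_tile::index_t M = 1024;
    const ck_tile::index_t N = 512;

    // Row-major M x N input reduced along N: keep dim 0, reduce dim 1.
    const auto in_shape   = ck_tile::make_tuple(M, N);
    const auto in_strides = ck_tile::make_tuple(N, ck_tile::index_t{1});

    // The output is 1-D of length M, so M is its contiguous dimension; per
    // IsSupportedArgument it must be a multiple of ThreadTile_N, and the
    // input's last stride must be 1.
    if(!Kernel::IsSupportedArgument(M, in_strides))
        throw std::runtime_error("unsupported reduce configuration");

    // One workgroup handles Block_M kept elements; launch through
    // make_kernel/launch_kernel with the argument list
    // (p_x, p_y, in_shape, in_strides, sequence<0>{}, sequence<1>{}),
    // exactly as in example/ck_tile/05_reduce/reduce.cpp.
    const ck_tile::index_t grid = ck_tile::integer_divide_ceil(M, Shape::Block_M);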