From d73a2287f36160f71c9062a48af686a268ecbeab Mon Sep 17 00:00:00 2001 From: Yashvardhan Agarwal Date: Wed, 17 Dec 2025 21:46:08 +0200 Subject: [PATCH] [ck_tile] refactor reduce kernel (#3257) * refactor reduce kernel - Rename Reduce kernel as per convention - Move kept_dim and reduce_dims from runtime to compile-time parameters - Update Reduce2dProblem template to include KeptDim, ReduceDims, and Rank - Remove IsSupportedArgument validation function as it's unnecessary. Not using the GuaranteedLastDimensionVectorStride while making tensor view or descriptor which removes the bounds enforced earlier. We still calculate and use vector size. - Update reduce example to demonstrate NCHW->NHW reduction with non-contiguous support - Update tests Kernel now handles both contiguous and non-contiguous memory layout. * fix compile errors [ROCm/composable_kernel commit: ea10a782036688cdc2a91266f675125bf1c5c59d] --- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 32 +++---- example/ck_tile/05_reduce/reduce.cpp | 66 +++++++------- .../ops/reduce/kernel/reduce2d_kernel.hpp | 91 +++++-------------- .../ops/reduce/pipeline/reduce2d_problem.hpp | 7 ++ test/ck_tile/reduce/test_reduce2d.cpp | 23 ++--- 5 files changed, 89 insertions(+), 130 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index abad4ab5c4..c06dc457c9 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -286,7 +286,6 @@ template float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) { - const ck_tile::index_t reduce_dim_size = args.k_batch; // Number of partial results to reduce // Calculate output size based on the final output tensor dimensions const ck_tile::index_t output_size = args.M * args.N; @@ -303,27 +302,28 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Reduce k_batch dimension using ReduceOp = ck_tile::ReduceOp::Add; - using BlockWarps = ck_tile::sequence<4, 1>; - using BlockTile = ck_tile::sequence<128, 128>; - using WarpTile = ck_tile::sequence<32, 128>; - using ThreadTile = ck_tile::sequence<8, 8>; + using BlockWarps = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<256, 1>; + using WarpTile = ck_tile::sequence<256, 1>; + using ThreadTile = ck_tile::sequence<1, 1>; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) / BlockTile::at(ck_tile::number<0>{}); - using Shape = ck_tile::Reduce2dShape; - using Problem = - ck_tile::Reduce2dProblem; - using Kernel = ck_tile::Reduce; + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile::Reduce2dProblem; + using Kernel = ck_tile::ReduceKernel; const ck_tile::index_t kBlockSize = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides)) - { - throw std::runtime_error("Wrong! Reduction arguments not supported!\n"); - } - if(s.log_level_ > 0) { std::cout << "Stage 2 - Launching Reduction kernel" << '\n' @@ -343,9 +343,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config static_cast(args.e_ptr), // workspace input static_cast(args.final_output_ptr), // final output workspace_shape, - workspace_strides, - kept_dim, - reduce_dims)); + workspace_strides)); return ave_time; } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index 677065c78d..f6742e613e 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -9,14 +9,14 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("n", "32", "n dimension") - .insert("h", "7", "h dimension") - .insert("w", "7", "w dimension") - .insert("c", "512", "c dimension") + arg_parser.insert("n", "16", "n dimension") + .insert("h", "64", "h dimension") + .insert("w", "32", "w dimension") + .insert("c", "960", "c dimension") .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") - .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter") + .insert("warmup", "20", "cold iter") + .insert("repeat", "100", "hot iter") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "reduce.json", "json file name to dump results"); @@ -47,12 +47,12 @@ bool run(const ck_tile::ArgParser& arg_parser) strides[3] = 1; // Define reduction specification: - constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep - constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + constexpr auto kept_dim = ck_tile::sequence<1, 2, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Which dimensions to reduce ck_tile::HostTensor x_host(problem_shape, strides); - ck_tile::HostTensor y_host_ref({N, C}, {C, 1}); - ck_tile::HostTensor y_host_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_ref({H, W, C}, {W * C, C, 1}); + ck_tile::HostTensor y_host_dev({H, W, C}, {W * C, C, 1}); ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); @@ -62,40 +62,40 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); using ReduceOp = ck_tile::ReduceOp::Add; - using BlockWarps = ck_tile::sequence<4, 1>; - using BlockTile = ck_tile::sequence<128, 128>; - using WarpTile = ck_tile::sequence<32, 128>; - using Vector = ck_tile::sequence<8, 8>; + using BlockWarps = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<256, 1>; + using WarpTile = ck_tile::sequence<256, 1>; + using ThreadTile = ck_tile::sequence<1, 1>; // cross warp-reduce // using BlockWarps = ck_tile::sequence<2, 2>; // using BlockTile = ck_tile::sequence<2, 1024>; // using WarpTile = ck_tile::sequence<1, 512>; - // using Vector = ck_tile::sequence<1, 8>; + // using ThreadTile = ck_tile::sequence<1, 8>; constexpr ck_tile::index_t kBlockPerCu = 1; - ck_tile::index_t kept_dim_len_prod = N * C; + ck_tile::index_t kept_dim_len_prod = H * W * C; ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / BlockTile::at(ck_tile::number<0>{}); std::cout << "grid size " << kGridSize << std::endl; - using Shape = ck_tile::Reduce2dShape; - using Porblem = - ck_tile::Reduce2dProblem; + using Shape = ck_tile::Reduce2dShape; + using Porblem = ck_tile::Reduce2dProblem; - using Kernel = ck_tile::Reduce; + using Kernel = ck_tile::ReduceKernel; const ck_tile::index_t kBlockSize = Kernel::BlockSize(); // Create input tensor shape and strides auto input_shape = ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); - if(!Kernel::IsSupportedArgument( - C, input_strides)) // output tensor's continuous dimension and input strides - { - throw std::runtime_error("Wrong! Arguments not supported!\n"); - } - float ave_time = launch_kernel( ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, ck_tile::make_kernel(Kernel{}, @@ -105,11 +105,9 @@ bool run(const ck_tile::ArgParser& arg_parser) static_cast(x_buf.GetDeviceBuffer()), static_cast(y_buf.GetDeviceBuffer()), input_shape, - input_strides, - kept_dim, - reduce_dims)); + input_strides)); - std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + std::size_t num_btype = sizeof(XDataType) * N * H * W * C + sizeof(YDataType) * H * W * C; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -149,8 +147,8 @@ int main(int argc, char* argv[]) { return run(arg_parser) ? 0 : -2; } - // else if(data_type == "bf16") - // { - // return run(arg_parser) ? 0 : -2; - // } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } } diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 1503b2b18b..dddfa26a53 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -16,7 +16,7 @@ namespace ck_tile { template -struct Reduce +struct ReduceKernel { using Problem = ck_tile::remove_cvref_t; using Policy = ck_tile::remove_cvref_t; @@ -33,7 +33,7 @@ struct Reduce private: // Helper function to calculate optimal vector size for input tensor - template + template static constexpr index_t CalculateInputVectorSize() { using S = typename Problem::BlockShape; @@ -41,8 +41,8 @@ struct Reduce constexpr index_t thread_tile_vector_size = S::ThreadTile_N; // Check if innermost reduce dimension is the last dimension (stride 1). - constexpr auto innermost_reduce_dim = ReduceDims{}.at(number{}); - constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1); + constexpr index_t innermost_reduce_dim = ReduceDims::at(number{}); + constexpr bool is_innermost_contiguous = (innermost_reduce_dim == Rank - 1); // If innermost reduce dimension is not the last dim (not contiguous), limit vectorization constexpr index_t stride_based_vector_size = @@ -63,29 +63,28 @@ struct Reduce } public: - template + template CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, InputShape input_shape, - InputStrides input_strides, - KeptDim kept_dim, - ReduceDims reduce_dims) const + InputStrides input_strides) const { using S = typename Problem::BlockShape; const auto iM = get_block_id() * S::Block_M; - static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(), + static_assert(Problem::KeptDim::size() + Problem::ReduceDims::size() == Problem::Rank, "Size of kept dimensions + reduced dimensions must equal input tensor rank"); // Extract lengths based on kept and reduced dimensions const auto kept_lens = [&]() { - return generate_tuple([&](auto I) { return input_shape.at(number{}); }, - number{}); + return generate_tuple( + [&](auto I) { return input_shape.at(number{}); }, + number{}); }(); const auto reduce_lens = [&]() { return generate_tuple( - [&](auto I) { return input_shape.at(number{}); }, - number{}); + [&](auto I) { return input_shape.at(number{}); }, + number{}); }(); const auto kept_merge_transform = make_merge_transform(kept_lens); @@ -96,11 +95,13 @@ struct Reduce type_convert(reduce_func.template GetIdentityValue()); // Calculate optimal vector size for input tensor - constexpr auto x_tensor_vector_size = CalculateInputVectorSize(); + constexpr auto x_tensor_vector_size = CalculateInputVectorSize(); // Create input tensor view with custom padding value auto desc = make_naive_tensor_descriptor( - input_shape, input_strides, number{}, number<1>{}); + input_shape, input_strides, number{}); // Create buffer view with custom padding value auto buffer_view = make_buffer_view( @@ -109,10 +110,11 @@ struct Reduce // Create tensor view with custom padding const auto x_tensor = tensor_view{buffer_view, desc}; const auto transformed_x_tensor = pad_tensor_view( - transform_tensor_view(x_tensor, - make_tuple(kept_merge_transform, reduce_merge_transform), - make_tuple(kept_dim, reduce_dims), - make_tuple(sequence<0>{}, sequence<1>{})), + transform_tensor_view( + x_tensor, + make_tuple(kept_merge_transform, reduce_merge_transform), + make_tuple(typename Problem::KeptDim{}, typename Problem::ReduceDims{}), + make_tuple(sequence<0>{}, sequence<1>{})), make_tuple(number{}, number{}), sequence<0, 1>{}); @@ -122,25 +124,25 @@ struct Reduce [&](auto I) { // Calculate stride for dimension I as product of all following dimensions index_t stride = 1; - static_for{}( + static_for{}( [&](auto J) { stride *= kept_lens.at(number{}); }); return stride; }, - number{}); + number{}); }(); // Calculate optimal vector size for output tensor constexpr auto y_tensor_vector_size = CalculateOutputVectorSize(); const auto y_m = make_naive_tensor_view( - p_y, kept_lens, kept_strides, number{}, number<1>{}); + p_y, kept_lens, kept_strides, number{}); // Transform output tensor to 1D merged view // This creates a view compatible with the 2D reduction pattern const auto y_merged = transform_tensor_view( y_m, make_tuple(kept_merge_transform), - make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}), + make_tuple(typename arithmetic_sequence_gen<0, Problem::KeptDim::size(), 1>::type{}), make_tuple(sequence<0>{})); auto x_window = make_tile_window(transformed_x_tensor, @@ -179,49 +181,6 @@ struct Reduce store_tile(y_window, cast_tile(y_compute)); } - - /// @brief Validates if the given arguments are supported by the 2D reduction kernel. - /// - /// @param y_continous_dim Size of the continuous dimension of the output tensor. - /// Must be a multiple of ThreadTile_N for proper thread mapping. - /// - /// @param input_strides The stride configuration of the input tensor. - /// The last stride must be 1 to ensure contiguous memory access - /// and enable efficient vectorized loads. - /// - /// @return true if the arguments are supported, false otherwise. - /// Error messages are logged when CK_TILE_LOGGING is enabled. - /// - /// @note Requirements: - /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution) - /// - input_strides[-1] == 1 (for contiguous memory access) - template - CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, - InputStrides input_strides) - { - using S = typename Problem::BlockShape; - - if(y_continous_dim % S::ThreadTile_N != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!"); - } - return false; - } - - if(input_strides.at(number{}) != 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Input tensor's last stride must be 1 to support correct vector access!"); - } - return false; - } - - return true; - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp index 1298bff274..83049b832e 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp @@ -12,6 +12,9 @@ template struct Reduce2dProblem { @@ -20,7 +23,11 @@ struct Reduce2dProblem using YDataType = remove_cvref_t; using BlockShape = remove_cvref_t; using ReduceOp = ReduceOp_; + using KeptDim = remove_cvref_t; + using ReduceDims = remove_cvref_t; + static constexpr index_t Rank = Rank_; + static constexpr index_t NumReduceDim = ReduceDims::size(); static constexpr bool kOutputIndex = OutputIndex_; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index 5513729f44..93ce3fd565 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -53,10 +53,16 @@ class TestCkTileReduce : public ::testing::Test d_y_mem.ToDevice(h_y.data()); // Initialize device output buffer // Problem and kernel setup - using Problem = ck_tile:: - Reduce2dProblem; + using Problem = ck_tile::Reduce2dProblem; - using Kernel = ck_tile::Reduce; + using Kernel = ck_tile::ReduceKernel; // Launch configuration const ck_tile::index_t kBlockSize = Kernel::BlockSize(); @@ -75,13 +81,6 @@ class TestCkTileReduce : public ::testing::Test auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); - if(!Kernel::IsSupportedArgument( - output_shape[output_shape.size() - 1], - input_strides_tuple)) // output tensor's continuous dimension - { - throw std::runtime_error("Wrong! Arguments not supported!\n"); - } - ck_tile::launch_kernel( ck_tile::stream_config{nullptr, false, 0}, ck_tile::make_kernel(Kernel{}, @@ -91,9 +90,7 @@ class TestCkTileReduce : public ::testing::Test static_cast(d_x_mem.GetDeviceBuffer()), static_cast(d_y_mem.GetDeviceBuffer()), input_shape_tuple, - input_strides_tuple, - kept_dims, - reduce_dims)); + input_strides_tuple)); // Get results back d_y_mem.FromDevice(h_y.data());