From bc9e9df38f68ab7a776468c01863b571df739ed4 Mon Sep 17 00:00:00 2001 From: Yashvardhan Agarwal Date: Thu, 9 Oct 2025 17:13:26 +0300 Subject: [PATCH] [CK_TILE] Pooling FWD (Lwpck 3683) (#2956) * Pooling 2D/3D with reference * Tests & cleanup - added test for pooling - cleanup - removed 2d example * Comment resolution - README added - example target name rectified - appropriate arg description and comments added * clang-format * appropriate blocksize calc * modifications for future indexing addition - instead of transforming views we now transform the descriptors, so that the same descriptor can be re-used for index tensor in the future * some basic fixes * comment resolutions * comment resolutions --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: 7b6451b68efada41ede3e1effb2c0aa9cf1ca314] --- CHANGELOG.md | 1 + example/ck_tile/36_pooling/CMakeLists.txt | 8 + example/ck_tile/36_pooling/README.md | 42 ++ example/ck_tile/36_pooling/pool3d.cpp | 188 +++++++ example/ck_tile/CMakeLists.txt | 1 + .../ck_tile/host/reference/reference_pool.hpp | 147 ++++++ include/ck_tile/ops/pool.hpp | 11 + .../ops/pooling/kernel/pool_kernel.hpp | 496 ++++++++++++++++++ .../pooling/pipeline/pool_default_policy.hpp | 80 +++ .../ops/pooling/pipeline/pool_problem.hpp | 33 ++ .../ops/pooling/pipeline/pool_shape.hpp | 57 ++ test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/pooling/CMakeLists.txt | 3 + test/ck_tile/pooling/test_pooling.cpp | 249 +++++++++ 14 files changed, 1317 insertions(+) create mode 100644 example/ck_tile/36_pooling/CMakeLists.txt create mode 100644 example/ck_tile/36_pooling/README.md create mode 100644 example/ck_tile/36_pooling/pool3d.cpp create mode 100644 include/ck_tile/host/reference/reference_pool.hpp create mode 100644 include/ck_tile/ops/pool.hpp create mode 100644 include/ck_tile/ops/pooling/kernel/pool_kernel.hpp create mode 100644 include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp create mode 
100644 include/ck_tile/ops/pooling/pipeline/pool_problem.hpp create mode 100644 include/ck_tile/ops/pooling/pipeline/pool_shape.hpp create mode 100644 test/ck_tile/pooling/CMakeLists.txt create mode 100644 test/ck_tile/pooling/test_pooling.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 9aadc3dc54..a8fe7b4afb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. * Added support for f32 to FMHA (fwd/bwd). * Added tensor-wise quantization for CK_TILE GEMM. +* Added pooling kernel in CK_TILE ### Optimized diff --git a/example/ck_tile/36_pooling/CMakeLists.txt b/example/ck_tile/36_pooling/CMakeLists.txt new file mode 100644 index 0000000000..425a8c83ba --- /dev/null +++ b/example/ck_tile/36_pooling/CMakeLists.txt @@ -0,0 +1,8 @@ +set(EXAMPLE_POOL_3D "tile_example_pool3d") +message(DEBUG "adding example ${EXAMPLE_POOL_3D}") + +add_executable(${EXAMPLE_POOL_3D} EXCLUDE_FROM_ALL pool3d.cpp) +target_include_directories(${EXAMPLE_POOL_3D} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +target_compile_options(${EXAMPLE_POOL_3D} PRIVATE ${EXAMPLE_POOL_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/36_pooling/README.md b/example/ck_tile/36_pooling/README.md new file mode 100644 index 0000000000..ab49b57095 --- /dev/null +++ b/example/ck_tile/36_pooling/README.md @@ -0,0 +1,42 @@ +# Pooling Operator + +This folder contains example for the pooling operator using ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +../script/cmake-ck-dev.sh ../ +# The 3D pooling example +make tile_example_pool3d -j`nproc` +``` +This will result in an executable `build/bin/tile_example_pool3d` + +## example +``` +args: + -N batch size (default:2) + -D depth dimension (default:30) + -H height dimension (default:30) + -W width dimension (default:30) + -C channel dimension (default:32) + -Z pooling window depth (default:2) + -Y pooling window height (default:2) + -X pooling window width (default:2) + -Sz window stride depth (default:2) + -Sy window stride height (default:2) + -Sx window stride width (default:2) + -Dz window dilation depth (default:1) + -Dy window dilation height (default:1) + -Dx window dilation width (default:1) + -LeftPz left padding depth (default:1) + -LeftPy left padding height (default:1) + -LeftPx left padding width (default:1) + -RightPz right padding depth (default:1) + -RightPy right padding height (default:1) + -RightPx right padding width (default:1) + -v 0: No validation, 1: CPU validation (default:1) + -warmup number of iterations before benchmark (default:0) + -repeat number of iterations to benchmark (default:1) +``` diff --git a/example/ck_tile/36_pooling/pool3d.cpp b/example/ck_tile/36_pooling/pool3d.cpp new file mode 100644 index 0000000000..bdfa1d99b3 --- /dev/null +++ b/example/ck_tile/36_pooling/pool3d.cpp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/host.hpp" +#include "ck_tile/ops/pool.hpp" +#include "ck_tile/host/reference/reference_pool.hpp" +#include + +// Parse command-line arguments for 3D pooling example +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("N", "2", "N dimension") + .insert("H", "30", "H dimension") + .insert("W", "30", "W dimension") + .insert("C", "32", "C dimension") + .insert("D", "30", "D dimension") + .insert("Z", "2", "Z dimension") + .insert("Y", "2", "Y dimension") + .insert("X", "2", "X dimension") + .insert("Sz", "2", "window stride d") + .insert("Sy", "2", "window stride h") + .insert("Sx", "2", "window stride w") + .insert("Dz", "1", "window dilation d") + .insert("Dy", "1", "window dilation h") + .insert("Dx", "1", "window dilation w") + .insert("LeftPz", "1", "left padding d") + .insert("LeftPy", "1", "left padding h") + .insert("LeftPx", "1", "left padding w") + .insert("RightPz", "1", "right padding d") + .insert("RightPy", "1", "right padding h") + .insert("RightPx", "1", "right padding w") + .insert("v", "1", "cpu validation or not") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + + const ck_tile::index_t N = arg_parser.get_int("N"); + const ck_tile::index_t H = arg_parser.get_int("H"); + const ck_tile::index_t W = arg_parser.get_int("W"); + const ck_tile::index_t C = arg_parser.get_int("C"); + const ck_tile::index_t D = arg_parser.get_int("D"); + + const ck_tile::index_t Z = arg_parser.get_int("Z"); + const ck_tile::index_t Y = arg_parser.get_int("Y"); + const ck_tile::index_t X = arg_parser.get_int("X"); + + const ck_tile::index_t Sz = arg_parser.get_int("Sz"); + const ck_tile::index_t Sy = arg_parser.get_int("Sy"); + const ck_tile::index_t Sx = arg_parser.get_int("Sx"); + + const ck_tile::index_t Dz = 
arg_parser.get_int("Dz"); + const ck_tile::index_t Dy = arg_parser.get_int("Dy"); + const ck_tile::index_t Dx = arg_parser.get_int("Dx"); + + const ck_tile::index_t LeftPz = arg_parser.get_int("LeftPz"); + const ck_tile::index_t LeftPy = arg_parser.get_int("LeftPy"); + const ck_tile::index_t LeftPx = arg_parser.get_int("LeftPx"); + const ck_tile::index_t RightPz = arg_parser.get_int("RightPz"); + const ck_tile::index_t RightPy = arg_parser.get_int("RightPy"); + const ck_tile::index_t RightPx = arg_parser.get_int("RightPx"); + + const ck_tile::index_t Zs = (Z - 1) * Dz + 1; + const ck_tile::index_t Ys = (Y - 1) * Dy + 1; + const ck_tile::index_t Xs = (X - 1) * Dx + 1; + + const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1; + const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1; + const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1; + + printf("Input parameters:\n"); + printf("N: %d, D: %d, H: %d, W: %d, C: %d\n", N, D, H, W, C); + printf("Window Z: %d, Y: %d, X: %d, Stride Z: %d, Y: %d, X: %d\n", Z, Y, X, Sz, Sy, Sx); + printf("Output Do: %d, Ho: %d, Wo: %d\n", Do, Ho, Wo); + + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Shapes / strides / parameters (NDHWC) + const auto input_shape = ck_tile::make_tuple(N, D, H, W, C); + const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C); + const auto input_strides = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1); + const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1); + const auto window_spatial_lengths = ck_tile::make_tuple(Z, Y, X); + const auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx); + const auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx); + const auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx); + const auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx); + + 
ck_tile::HostTensor in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1}); + ck_tile::HostTensor out({N, Do, Ho, Wo, C}, + {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); + ck_tile::HostTensor out_ref({N, Do, Ho, Wo, C}, + {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(in); + + ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes()); + + in_buf.ToDevice(in.data()); + + using ReduceOp = ck_tile::ReduceOp::Max; + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + using Shape = ck_tile::PoolShape; + using Problem = ck_tile::PoolProblem; + using Kernel = ck_tile::PoolKernel; + + constexpr ck_tile::index_t kBlockPerCu = 1; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + + auto host_args = ck_tile::PoolHostArgs{ + static_cast(in_buf.GetDeviceBuffer()), + static_cast(out_buf.GetDeviceBuffer()), + input_shape, + output_shape, + input_strides, + output_strides, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads}; + + auto kernel_args = Kernel::MakeKernelArgs(host_args); + + const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args); + std::cout << "grid size " << kGridSize << std::endl; + + // Validate kernel can handle the given configuration + if(!Kernel::IsSupportedArgument(kernel_args)) + { + throw std::runtime_error("ERROR: Kernel arguments are not supported! 
\n"); + } + + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, kernel_args)); + + std::size_t num_btype = + sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + ck_tile::reference_pool3d( + in, out_ref, kernel_args, ReduceOp{}); + out_buf.FromDevice(out.mData.data()); + pass = ck_tile::check_err(out, out_ref); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + return run(arg_parser) ? 0 : -2; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 931f2d315f..7a8ae065db 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,6 +23,7 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) +add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) diff --git a/include/ck_tile/host/reference/reference_pool.hpp b/include/ck_tile/host/reference/reference_pool.hpp new file mode 100644 index 0000000000..1b3e45bce8 --- /dev/null +++ b/include/ck_tile/host/reference/reference_pool.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_pool2d(const HostTensor& input, + HostTensor& output, + PoolKernelArgs kargs, + ReduceOp reduce_op) +{ + const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{}); + const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{}); + const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{}); + const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{}); + + const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{}); + const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{}); + + const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{}); + const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{}); + + const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{}); + const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{}); + + const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{}); + const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{}); + + const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{}); + const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{}); + // Right padding is handled implicitly by bounds checking + + auto f = [&](auto n, auto ho, auto wo, auto c) { + ComputeDataType v_acc = reduce_op.template GetIdentityValue(); + + for(ck_tile::index_t y = 0; y < Y; ++y) + { + // Calculate input height index with stride, dilation, and padding + ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy; + + for(ck_tile::index_t x = 0; x < X; ++x) + { + // Calculate input width index with stride, dilation, and padding + ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx; + + if(hi >= 0 && hi < H && wi >= 0 && wi < W) + { + const ComputeDataType v_in = type_convert(input(n, hi, wi, c)); + v_acc = 
reduce_op(v_acc, v_in); + } + // For positions outside bounds, we implicitly use identity value + } + } + + output(n, ho, wo, c) = ck_tile::type_convert(v_acc); + }; + + // Parallelize over all output dimensions + make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency()); +} + +template +CK_TILE_HOST void reference_pool3d(const HostTensor& input, + HostTensor& output, + PoolKernelArgs kargs, + ReduceOp reduce_op) +{ + const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{}); + const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{}); + const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{}); + const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{}); + const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{}); + + const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{}); + const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{}); + const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{}); + + const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{}); + const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{}); + const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{}); + + const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{}); + const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{}); + const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{}); + + const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{}); + const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{}); + const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{}); + + const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{}); + const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{}); + const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{}); + // Right 
padding is handled implicitly by bounds checking + + auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) { + ComputeDataType v_acc = reduce_op.template GetIdentityValue(); + + for(ck_tile::index_t z = 0; z < Z; ++z) + { + // Calculate input depth index with stride, dilation, and padding + ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz; + + for(ck_tile::index_t y = 0; y < Y; ++y) + { + // Calculate input height index with stride, dilation, and padding + ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy; + + for(ck_tile::index_t x = 0; x < X; ++x) + { + // Calculate input width index with stride, dilation, and padding + ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx; + + if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W) + { + const ComputeDataType v_in = + type_convert(input(n, di, hi, wi, c)); + v_acc = reduce_op(v_acc, v_in); + } + // For positions outside bounds, we implicitly use identity value + } + } + } + + output(n, do_, ho, wo, c) = ck_tile::type_convert(v_acc); + }; + + // Parallelize over all output dimensions + make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/pool.hpp b/include/ck_tile/ops/pool.hpp new file mode 100644 index 0000000000..350ef17dcb --- /dev/null +++ b/include/ck_tile/ops/pool.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp" +#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" +#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp new file mode 100644 index 0000000000..93567e7161 --- /dev/null +++ b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp @@ -0,0 +1,496 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp" +#include "ck_tile/ops/common.hpp" +#include + +namespace ck_tile { + +/// @brief Host arguments for pooling operations +template +struct PoolHostArgs +{ + + CK_TILE_HOST PoolHostArgs(const void* input_ptr_, + void* output_ptr_, + TensorShape input_shape_, + TensorShape output_shape_, + TensorShape input_strides_, + TensorShape output_strides_, + WindowShape window_lengths_, + WindowShape window_strides_, + WindowShape window_dilations_, + WindowShape input_left_pads_, + WindowShape input_right_pads_) + : input_ptr(input_ptr_), + output_ptr(output_ptr_), + input_shape(input_shape_), + output_shape(output_shape_), + input_strides(input_strides_), + output_strides(output_strides_), + window_lengths(window_lengths_), + window_strides(window_strides_), + window_dilations(window_dilations_), + input_left_pads(input_left_pads_), + input_right_pads(input_right_pads_) + { + } + + const void* input_ptr; + void* output_ptr; + + TensorShape input_shape; + TensorShape output_shape; + TensorShape input_strides; + TensorShape output_strides; + WindowShape window_lengths; + WindowShape window_strides; + WindowShape window_dilations; + WindowShape input_left_pads; + WindowShape 
input_right_pads; +}; + +/// @brief Kernel arguments for pooling operations +template +struct PoolKernelArgs +{ + const void* input_ptr; + void* output_ptr; + TensorShape input_shape; + TensorShape output_shape; + TensorShape input_strides; + TensorShape output_strides; + WindowShape window_lengths; + WindowShape window_strides; + WindowShape window_dilations; + WindowShape input_left_pads; + WindowShape input_right_pads; +}; + +template +struct PoolKernel +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using InDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using OutDataType = ck_tile::remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? kBlockSize / 2 : kBlockSize; + } + + template + static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs kargs) + { + using S = typename Problem::BlockShape; + + // Compile-time validation for 2D pooling + static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)"); + static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)"); + + // Extract dimension values + const index_t N = kargs.input_shape.at(number<0>{}); + const index_t H = kargs.input_shape.at(number<1>{}); + const index_t W = kargs.input_shape.at(number<2>{}); + const index_t C = kargs.input_shape.at(number<3>{}); + + const index_t No = kargs.output_shape.at(number<0>{}); + const index_t Ho = kargs.output_shape.at(number<1>{}); + const index_t Wo = kargs.output_shape.at(number<2>{}); + const index_t Co = kargs.output_shape.at(number<3>{}); + + const index_t Y = kargs.window_lengths.at(number<0>{}); + const index_t X = kargs.window_lengths.at(number<1>{}); + + const index_t WindowStrideH = kargs.window_strides.at(number<0>{}); + const index_t WindowStrideW = kargs.window_strides.at(number<1>{}); + + const index_t 
WindowDilationH = kargs.window_dilations.at(number<0>{}); + const index_t WindowDilationW = kargs.window_dilations.at(number<1>{}); + + const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{}); + const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{}); + + const index_t InRightPadH = kargs.input_right_pads.at(number<0>{}); + const index_t InRightPadW = kargs.input_right_pads.at(number<1>{}); + + const index_t MRaw = N * Ho * Wo * C; + const index_t KRaw = Y * X; + const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw; + const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw; + + auto reduce_op = typename Problem::ReduceOp{}; + + // Create input descriptor with all transformations + auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides); + + // Apply spatial padding to input descriptor + const auto padded_in_desc = transform_tensor_descriptor( + in_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(H, InLeftPadH, InRightPadH), + make_pad_transform(W, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + // Create sliding windows by embedding pooling windows into descriptor + const auto embed_in_desc = transform_tensor_descriptor( + padded_in_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + // Reshape into 2D matrix: output positions (M) x pooling window elements (K) + const auto merged_embed_in_desc = + transform_tensor_descriptor(embed_in_desc, + 
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto in_desc_padded = transform_tensor_descriptor( + merged_embed_in_desc, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // Create output descriptor with transformations + auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides); + + const auto merged_out_desc = transform_tensor_descriptor( + out_desc, + make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))), + make_tuple(sequence<0, 1, 2, 3>{}), + make_tuple(sequence<0>{})); + + const auto out_desc_padded = + transform_tensor_descriptor(merged_out_desc, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(sequence<0>{}), + make_tuple(sequence<0>{})); + + // Now create buffer views and tensor views with the fully transformed descriptors + const InDataType in_identity = + type_convert(reduce_op.template GetIdentityValue()); + const OutDataType out_identity = + type_convert(reduce_op.template GetIdentityValue()); + + auto in_buffer_view = make_buffer_view( + static_cast(kargs.input_ptr), + in_desc.get_element_space_size(), + in_identity); + const auto in_tensor_padded = + tensor_view{in_buffer_view, + in_desc_padded}; + + auto out_buffer_view = make_buffer_view( + static_cast(kargs.output_ptr), + out_desc.get_element_space_size(), + out_identity); + const auto out_tensor_padded = + tensor_view{out_buffer_view, + out_desc_padded}; + + return make_tuple(in_tensor_padded, out_tensor_padded); + } + + template + static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs kargs) + { + using S = typename Problem::BlockShape; + + // Compile-time validation for 3D pooling + static_assert(TensorShape::size() == 5, "3D pooling requires 5D 
input tensor (N,D,H,W,C)"); + static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)"); + + // Extract dimension values + const index_t N = kargs.input_shape.at(number<0>{}); + const index_t D = kargs.input_shape.at(number<1>{}); + const index_t H = kargs.input_shape.at(number<2>{}); + const index_t W = kargs.input_shape.at(number<3>{}); + const index_t C = kargs.input_shape.at(number<4>{}); + + const index_t No = kargs.output_shape.at(number<0>{}); + const index_t Do = kargs.output_shape.at(number<1>{}); + const index_t Ho = kargs.output_shape.at(number<2>{}); + const index_t Wo = kargs.output_shape.at(number<3>{}); + const index_t Co = kargs.output_shape.at(number<4>{}); + + const index_t Z = kargs.window_lengths.at(number<0>{}); + const index_t Y = kargs.window_lengths.at(number<1>{}); + const index_t X = kargs.window_lengths.at(number<2>{}); + + const index_t WindowStrideD = kargs.window_strides.at(number<0>{}); + const index_t WindowStrideH = kargs.window_strides.at(number<1>{}); + const index_t WindowStrideW = kargs.window_strides.at(number<2>{}); + + const index_t WindowDilationD = kargs.window_dilations.at(number<0>{}); + const index_t WindowDilationH = kargs.window_dilations.at(number<1>{}); + const index_t WindowDilationW = kargs.window_dilations.at(number<2>{}); + + const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{}); + const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{}); + const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{}); + + const index_t InRightPadD = kargs.input_right_pads.at(number<0>{}); + const index_t InRightPadH = kargs.input_right_pads.at(number<1>{}); + const index_t InRightPadW = kargs.input_right_pads.at(number<2>{}); + + const index_t MRaw = N * Do * Ho * Wo * C; + const index_t KRaw = Z * Y * X; + const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw; + const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw; + + auto reduce_op = typename 
Problem::ReduceOp{}; + + // Create input descriptor with all transformations + auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides); + + // Apply spatial padding to input descriptor (all 3D dimensions) + const auto padded_in_desc = transform_tensor_descriptor( + in_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(D, InLeftPadD, InRightPadD), + make_pad_transform(H, InLeftPadH, InRightPadH), + make_pad_transform(W, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + // Create 3D sliding windows by embedding pooling windows into descriptor + const auto embed_in_desc = transform_tensor_descriptor( + padded_in_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + // Reshape into 2D matrix: output positions (M) x pooling window elements (K) + const auto merged_embed_in_desc = transform_tensor_descriptor( + embed_in_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)), + make_merge_transform(make_tuple(Z, Y, X))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto in_desc_padded = transform_tensor_descriptor( + merged_embed_in_desc, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(sequence<0>{}, 
sequence<1>{}),
+                       make_tuple(sequence<0>{}, sequence<1>{}));
+
+        // Create output descriptor with transformations
+        auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
+
+        // Collapse (No, Do, Ho, Wo, Co) into the single flattened M dimension so the
+        // output lines up row-for-row with the merged input descriptor above.
+        const auto merged_out_desc = transform_tensor_descriptor(
+            out_desc,
+            make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
+            make_tuple(sequence<0, 1, 2, 3, 4>{}),
+            make_tuple(sequence<0>{}));
+
+        const auto out_desc_padded =
+            transform_tensor_descriptor(merged_out_desc,
+                                        make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                        make_tuple(sequence<0>{}),
+                                        make_tuple(sequence<0>{}));
+
+        // Now create buffer views and tensor views with the fully transformed descriptors.
+        // NOTE(review): the reduce op's identity is passed into make_buffer_view,
+        // presumably as the fill value for out-of-range accesses so padded elements
+        // cannot perturb the reduction -- confirm against the buffer_view API.
+        // NOTE(review): template argument lists below (type_convert, static_cast,
+        // make_buffer_view, tensor_view) appear stripped by extraction -- restore
+        // from the upstream file.
+        const InDataType in_identity =
+            type_convert(reduce_op.template GetIdentityValue());
+        const OutDataType out_identity =
+            type_convert(reduce_op.template GetIdentityValue());
+
+        auto in_buffer_view = make_buffer_view(
+            static_cast(kargs.input_ptr),
+            in_desc.get_element_space_size(),
+            in_identity);
+        const auto in_tensor_padded =
+            tensor_view{in_buffer_view,
+                        in_desc_padded};
+
+        auto out_buffer_view = make_buffer_view(
+            static_cast(kargs.output_ptr),
+            out_desc.get_element_space_size(),
+            out_identity);
+        const auto out_tensor_padded =
+            tensor_view{out_buffer_view,
+                        out_desc_padded};
+
+        return make_tuple(in_tensor_padded, out_tensor_padded);
+    }
+
+    public:
+    /// @brief Device entry point: computes one Block_M-wide strip of pooled outputs.
+    ///
+    /// The pooling problem is flattened to a 2D reduction: M indexes all output
+    /// positions (the merged N*[D]*H*W*C extent) and K indexes the elements of one
+    /// pooling window; each workgroup reduces Block_M rows over the whole K extent.
+    ///
+    /// @param kargs Kernel arguments (shapes, strides, window geometry, pads).
+    template
+    CK_TILE_DEVICE void operator()(PoolKernelArgs kargs) const
+    {
+        using S = typename Problem::BlockShape;
+
+        // Compile-time validation for supported window dimensions
+        static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
+                      "Only 2D and 3D pooling operations are supported");
+
+        // First flattened output row handled by this workgroup.
+        const auto iM = get_block_id() * S::Block_M;
+
+        // Get tensors based on dimensionality
+        auto [in_tensor_padded, out_tensor_padded] = [&]() {
+            if constexpr(WindowShape::size() == 2)
+                return MakeTensorView2D(kargs);
+            else if constexpr(WindowShape::size() == 3)
+                return MakeTensorView3D(kargs);
+            else
+                static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
+                              "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
+        }();
+
+        auto reduce_op = typename Problem::ReduceOp{};
+
+        // Tile windows over the flattened views: (M x K) for the input, (M) for the output.
+        auto x_window = make_tile_window(in_tensor_padded,
+                                         make_tuple(number{}, number{}),
+                                         {iM, 0},
+                                         Policy::template MakeXBlockTileDistribution());
+        auto y_window = make_tile_window(out_tensor_padded, make_tuple(number{}), {iM});
+
+        __shared__ char smem[Policy::template GetSmemSize()];
+
+        // K extent: length of the merged (padded) pooling-window dimension.
+        const auto reduce_len =
+            in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
+        index_t num_k_tiles =
+            __builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
+
+        auto block_reduce2d = Policy::template GetBlockReduce2d();
+        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync();
+        auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync();
+
+        using XTensorTile = decltype(load_tile(x_window));
+        auto y_tile = block_reduce2d.template MakeYBlockTile();
+        set_tile(y_tile, reduce_op.template GetIdentityValue());
+
+        // Sweep K one Block_N-wide tile at a time, accumulating into y_tile.
+        for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
+        {
+            const auto x_tile = load_tile(x_window);
+            block_reduce2d(x_tile, y_tile, reduce_op);
+            move_tile_window(x_window, {0, S::Block_N});
+        }
+
+        // Combine per-thread partials across lanes, then across warps (through smem),
+        // before casting to the output type and storing.
+        block_reduce2d_sync(y_tile, reduce_op);
+        block_reduce2d_cross_warp(y_tile, smem, reduce_op);
+        store_tile(y_window, cast_tile(y_tile));
+    }
+
+    /// @brief Validates if the given arguments are supported by the pooling kernel.
+    ///
+    /// @param kargs The pooling kernel arguments containing all necessary parameters.
+    ///
+    /// @return true if the arguments are supported, false otherwise.
+    ///
+    /// @note Requirements:
+    ///       - Last dimension (C) must be contiguous (stride = 1) for vectorized access
+    ///       - Window dimensions must be supported (2D or 3D)
+    ///       - All dimension sizes must be consistent between input and output
+    template
+    CK_TILE_HOST static bool IsSupportedArgument(PoolKernelArgs kargs)
+    {
+        constexpr index_t InputRank = TensorShape::size();
+        constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
+        constexpr index_t WindowRank = WindowShape::size();
+
+        // Validate window dimensions (only 2D and 3D supported)
+        if constexpr(WindowRank != 2 && WindowRank != 3)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
+            }
+            return false;
+        }
+
+        // Validate that input rank matches expected rank for window dimensions
+        // (2D pooling expects a rank-4 tensor, 3D pooling a rank-5 NDHWC tensor).
+        if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
+            }
+            return false;
+        }
+
+        // Check that channel dimension (last dimension) is contiguous for both input and output
+        // NOTE(review): number{} argument lists below appear stripped by extraction
+        // (likely number<InputRank - 1>/number<OutputRank - 1>) -- restore upstream.
+        if(kargs.input_strides.at(number{}) != 1)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
+            }
+            return false;
+        }
+
+        if(kargs.output_strides.at(number{}) != 1)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
+            }
+            return false;
+        }
+
+        return true;
+    }
+
+    /// @brief Calculate the number of workgroups needed to cover all output elements.
+    ///
+    /// @param kargs The pooling kernel arguments
+    /// @return The calculated grid size
+    template
+    CK_TILE_HOST static constexpr index_t
+    CalculateGridSize(PoolKernelArgs kargs)
+    {
+        using S = typename Problem::BlockShape;
+
+        // Calculate total output elements (M dimension)
+        index_t M = 1;
+        static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
+
+        // Calculate grid size: ceil(M / Block_M)
+        return (M + S::Block_M - 1) / S::Block_M;
+    }
+
+    /// @brief Create kernel arguments from host arguments
+    template
+    CK_TILE_HOST static constexpr auto
+    MakeKernelArgs(PoolHostArgs& host_args)
+    {
+        return PoolKernelArgs{host_args.input_ptr,
+                              host_args.output_ptr,
+                              host_args.input_shape,
+                              host_args.output_shape,
+                              host_args.input_strides,
+                              host_args.output_strides,
+                              host_args.window_lengths,
+                              host_args.window_strides,
+                              host_args.window_dilations,
+                              host_args.input_left_pads,
+                              host_args.input_right_pads};
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp
new file mode 100644
index 0000000000..a5b5fac63d
--- /dev/null
+++ b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
+
+namespace ck_tile {
+
+/// Default policy for the pooling kernel: supplies the input-tile distribution,
+/// the block-level 2D reduction helpers, and the shared-memory footprint.
+struct PoolDefaultPolicy
+{
+    /// Static distribution of the (Block_M x Block_N) input tile across warps/lanes.
+    /// NOTE(review): several sequence<>/tuple<> argument lists below appear stripped
+    /// by extraction -- restore the encoding from the upstream file before use.
+    template
+    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
+    {
+        using S = typename Problem::BlockShape;
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<>,
+                tuple<
+                    sequence,
+                    sequence>,
+                tuple, sequence<1, 2>>,
+                tuple, sequence<2, 2>>,
+                sequence<1, 1, 2, 2>,
+                sequence<0, 3, 0, 3>>{});
+    }
+
+    /// Block-scope 2D reduction that folds loaded X tiles into per-thread partials.
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    {
+        using P_ = BlockReduce2dProblem;
+        return BlockReduce2d{};
+    }
+
+    /// Cross-lane combine step of the 2D reduction.
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    {
+        using P_ = BlockReduce2dProblem;
+        return BlockReduce2dSync{};
+    }
+
+    /// Cross-warp combine step; exchanges partials through shared memory.
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    {
+        using P_ = BlockReduce2dProblem;
+        return BlockReduce2dCrossWarpSync{};
+    }
+
+    /// Shared-memory size (bytes) needed by the cross-warp combine; 1 when no
+    /// cross-warp sync is required, since a zero-sized array would be non-standard.
+    template
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        if constexpr(Problem::kNeedCrossWarpSync)
+        {
+            using P_ = BlockReduce2dProblem;
+
+            using block_reduce2d = BlockReduce2d;
+            using x_block_tile =
+                decltype(make_static_distributed_tensor(
+                    MakeXBlockTileDistribution()));
+            using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile());
+
+            return GetBlockReduce2dCrossWarpSync().template GetSmemSize();
+        }
+        else
+        {
+            return 1; // zero size arrays are an extension
+        }
+    }
+};
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp
new file mode 100644
index 0000000000..83a43318bc
--- /dev/null
+++ b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+/// Compile-time problem description for the pooling kernel: data types, block
+/// shape, reduce op, and flags for index output and NaN propagation.
+/// NOTE(review): the template parameter list appears stripped by extraction;
+/// the trailing-underscore names referenced below are the intended parameters.
+template
+struct PoolProblem
+{
+    using InDataType = remove_cvref_t;
+    using OutDataType = remove_cvref_t;
+    using ComputeDataType = remove_cvref_t;
+    using IndexDataType = remove_cvref_t;
+    using BlockShape = remove_cvref_t;
+    using ReduceOp = ReduceOp_;
+    using OutputIndex = bool_constant;
+    using PropagateNan = bool_constant;
+
+    // A reduction row is split across lanes/warps only when the N (reduce) axis of
+    // the block shape spans more than one of them.
+    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
+    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp b/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp
new file mode 100644
index 0000000000..5879fe593e
--- /dev/null
+++ b/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp
@@ -0,0 +1,57 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+/// Describes the 2D (M x N) block/warp/thread tiling used by the pooling kernel,
+/// where M indexes output elements and N indexes elements reduced per output.
+/// NOTE(review): the first template parameter line (warps per block, seq --
+/// referenced below as BlockWarps) appears stripped by extraction.
+template
+          typename BlockTile, // block size, seq
+          typename WarpTile, // warp size, seq
+          typename ThreadTile> // contiguous pixels(vector size) along seq
+struct PoolShape
+{
+    static constexpr index_t Block_M = BlockTile::at(number<0>{});
+    static constexpr index_t Block_N = BlockTile::at(number<1>{});
+
+    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
+    static constexpr index_t Warp_N = WarpTile::at(number<1>{});
+
+    static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
+    static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
+
+    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
+    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
+
+    static_assert(Warp_M % ThreadTile_M == 0, "Warp_M must be divisible by ThreadTile_M");
+    static_assert(Warp_N % ThreadTile_N == 0, "Warp_N must be divisible by ThreadTile_N");
+    static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % ck_tile::get_warp_size() == 0,
+                  "Warp_M * Warp_N / ThreadTile_M / ThreadTile_N must be a multiple of warp size");
+
+    // Scale factor to account for warp size
+    // WarpSizeScaleFactor = warp tile/ thread tile / warp size
+    static constexpr index_t WarpSizeScaleFactor =
+        Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
+
+    // The scale factor is applied entirely to whichever axis carries more threads.
+    static constexpr index_t WarpSizeScaleFactor_M =
+        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? WarpSizeScaleFactor : 1;
+    static constexpr index_t WarpSizeScaleFactor_N =
+        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : WarpSizeScaleFactor;
+
+    static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / WarpSizeScaleFactor_M;
+    static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / WarpSizeScaleFactor_N;
+
+    static_assert((Block_M * WarpSizeScaleFactor_M) % (WarpPerBlock_M * Warp_M) == 0,
+                  "Block_M * WarpSizeScaleFactor_M must be divisible by WarpPerBlock_M * Warp_M");
+    static_assert((Block_N * WarpSizeScaleFactor_N) % (WarpPerBlock_N * Warp_N) == 0,
+                  "Block_N * WarpSizeScaleFactor_N must be divisible by WarpPerBlock_N * Warp_N");
+
+    // How many times each warp tile repeats to cover the block tile, per axis.
+    static constexpr index_t Repeat_M = Block_M * WarpSizeScaleFactor_M / (WarpPerBlock_M * Warp_M);
+    static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
+
+    static constexpr index_t BlockSize =
+        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
+};
+} // namespace ck_tile
diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt
index 04be25f30a..5fa6918c10 100644
--- a/test/ck_tile/CMakeLists.txt
+++ b/test/ck_tile/CMakeLists.txt
@@ -31,3 +31,4 @@ add_subdirectory(epilogue)
 add_subdirectory(atomic_add_op)
 add_subdirectory(fmha)
 add_subdirectory(gemm_tile_engine)
+add_subdirectory(pooling)
diff --git a/test/ck_tile/pooling/CMakeLists.txt b/test/ck_tile/pooling/CMakeLists.txt
new file mode 100644
index 0000000000..83c36cb321
--- /dev/null
+++
b/test/ck_tile/pooling/CMakeLists.txt
@@ -0,0 +1,3 @@
+if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
+    add_gtest_executable(test_ck_tile_pooling test_pooling.cpp)
+endif()
diff --git a/test/ck_tile/pooling/test_pooling.cpp b/test/ck_tile/pooling/test_pooling.cpp
new file mode 100644
index 0000000000..3cec19d2d6
--- /dev/null
+++ b/test/ck_tile/pooling/test_pooling.cpp
@@ -0,0 +1,249 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+// NOTE(review): the angle-bracketed header names below appear stripped by
+// extraction (presumably <gtest/gtest.h> and standard headers) -- restore
+// from the upstream file.
+#include
+#include
+#include
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/pool.hpp"
+#include "ck_tile/host/reference/reference_pool.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+
+/// Typed-test fixture. Tuple packs one kernel configuration:
+/// (InDataType, OutDataType, ComputeDataType, ReduceOp, BlockWarps, BlockTile,
+///  WarpTile, ThreadTile) -- see the tuple_element_t aliases below.
+template
+class TestCkTilePooling : public ::testing::Test
+{
+    protected:
+    using InDataType = std::tuple_element_t<0, Tuple>;
+    using OutDataType = std::tuple_element_t<1, Tuple>;
+    using ComputeDataType = std::tuple_element_t<2, Tuple>;
+    using ReduceOpType = std::tuple_element_t<3, Tuple>;
+    using BlockWarps_ = std::tuple_element_t<4, Tuple>;
+    using BlockTile_ = std::tuple_element_t<5, Tuple>;
+    using WarpTile_ = std::tuple_element_t<6, Tuple>;
+    using ThreadTile_ = std::tuple_element_t<7, Tuple>;
+
+    // NOTE(review): PoolShape's template argument list appears stripped by
+    // extraction (presumably the four aliases above) -- restore upstream.
+    using TestPoolShape = ck_tile::PoolShape;
+
+    // 3D pooling configuration
+    struct Config3D
+    {
+        ck_tile::index_t N, D, H, W, C;             // input tensor extents (NDHWC)
+        ck_tile::index_t Z, Y, X;                   // pooling window extents
+        ck_tile::index_t Sz, Sy, Sx;                // window strides
+        ck_tile::index_t Dz, Dy, Dx;                // window dilations
+        ck_tile::index_t LeftPz, LeftPy, LeftPx;    // left padding
+        ck_tile::index_t RightPz, RightPy, RightPx; // right padding
+        std::string name;
+    };
+
+    /// Runs one 3D pooling config on device and checks it against the host
+    /// reference; returns true on match (or when the config is unsupported).
+    bool RunPool3D(const Config3D& config)
+    {
+        std::cout << "Testing 3D: " << config.name << " ... ";
+
+        // Effective (dilated) window extents.
+        const ck_tile::index_t Zs = (config.Z - 1) * config.Dz + 1;
+        const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
+        const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
+        // Output spatial extents from the standard pooling size formula.
+        const ck_tile::index_t Do =
+            (config.D + config.LeftPz + config.RightPz - Zs) / config.Sz + 1;
+        const ck_tile::index_t Ho =
+            (config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
+        const ck_tile::index_t Wo =
+            (config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;
+
+        const auto input_shape =
+            ck_tile::make_tuple(config.N, config.D, config.H, config.W, config.C);
+        const auto output_shape = ck_tile::make_tuple(config.N, Do, Ho, Wo, config.C);
+        // Packed NDHWC strides (C innermost, stride 1).
+        const auto input_strides = ck_tile::make_tuple(config.D * config.H * config.W * config.C,
+                                                       config.H * config.W * config.C,
+                                                       config.W * config.C,
+                                                       config.C,
+                                                       1);
+        const auto output_strides = ck_tile::make_tuple(
+            Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1);
+        const auto window_spatial_lengths = ck_tile::make_tuple(config.Z, config.Y, config.X);
+        const auto window_strides = ck_tile::make_tuple(config.Sz, config.Sy, config.Sx);
+        const auto window_dilations = ck_tile::make_tuple(config.Dz, config.Dy, config.Dx);
+        const auto input_left_pads =
+            ck_tile::make_tuple(config.LeftPz, config.LeftPy, config.LeftPx);
+        const auto input_right_pads =
+            ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
+
+        // NOTE(review): HostTensor element-type arguments appear stripped by
+        // extraction (presumably InDataType / OutDataType) -- restore upstream.
+        ck_tile::HostTensor h_in({config.N, config.D, config.H, config.W, config.C},
+                                 {config.D * config.H * config.W * config.C,
+                                  config.H * config.W * config.C,
+                                  config.W * config.C,
+                                  config.C,
+                                  1});
+        ck_tile::HostTensor h_out(
+            {config.N, Do, Ho, Wo, config.C},
+            {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
+        ck_tile::HostTensor h_out_ref(
+            {config.N, Do, Ho, Wo, config.C},
+            {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
+
+        ck_tile::FillUniformDistribution{-5.f, 5.f}(h_in);
+        h_out.SetZero();
+        h_out_ref.SetZero();
+
+        ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
+        ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
+
+        d_in_mem.ToDevice(h_in.data());
+        d_out_mem.ToDevice(h_out.data());
+
+        using Problem = ck_tile::PoolProblem;
+        using Kernel = ck_tile::PoolKernel;
+
+        constexpr ck_tile::index_t kBlockPerCu = 1;
+        const ck_tile::index_t kBlockSize = Kernel::BlockSize();
+
+        auto host_args =
+            ck_tile::PoolHostArgs{
+                static_cast(d_in_mem.GetDeviceBuffer()),
+                static_cast(d_out_mem.GetDeviceBuffer()),
+                input_shape,
+                output_shape,
+                input_strides,
+                output_strides,
+                window_spatial_lengths,
+                window_strides,
+                window_dilations,
+                input_left_pads,
+                input_right_pads};
+
+        auto kernel_args = Kernel::MakeKernelArgs(host_args);
+
+        const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
+
+        // NOTE(review): unsupported configs are silently treated as passing (a skip)
+        // and leave the "Testing 3D: ..." log line unterminated -- consider
+        // printing an explicit "SKIP" verdict.
+        if(!Kernel::IsSupportedArgument(kernel_args))
+        {
+            return true;
+        }
+
+        // Run kernel
+        ck_tile::launch_kernel(
+            ck_tile::stream_config{nullptr, false, 0},
+            ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
+
+        // Run reference implementation
+        ck_tile::reference_pool3d(
+            h_in, h_out_ref, kernel_args, ReduceOpType{});
+
+        d_out_mem.FromDevice(h_out.data());
+
+        // Validate results
+        bool pass = ck_tile::check_err(h_out, h_out_ref);
+        std::cout << (pass ? "PASS" : "FAIL") << std::endl;
+
+        return pass;
+    }
+};
+
+// Single-warp-per-N configuration (reduction fits within one warp in N).
+using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
+using Shape1_BlockTile = ck_tile::sequence<128, 128>;
+using Shape1_WarpTile = ck_tile::sequence<32, 128>;
+using Shape1_ThreadTile = ck_tile::sequence<8, 8>;
+
+// Cross-warp configuration
+using Shape2_BlockWarps = ck_tile::sequence<2, 2>;
+using Shape2_BlockTile = ck_tile::sequence<2, 1024>;
+using Shape2_WarpTile = ck_tile::sequence<1, 512>;
+using Shape2_ThreadTile = ck_tile::sequence<1, 8>;
+
+// Test configurations for different data types and operations
+// NOTE(review): the std::tuple<...> and ::testing::Types<...> argument lists
+// appear stripped by extraction -- restore from the upstream file.
+using TestConfig_F32_Max = std::tuple;
+
+using TestConfig_F16_Max = std::tuple;
+
+using TestConfig_F32_CrossWarp = std::tuple;
+
+using TestTypes =
+    ::testing::Types;
+
+TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
+
+// Unpadded, stride-2, non-overlapping 2x2x2 window.
+TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
+{
+    typename TestFixture::Config3D config = {1,  // N - batch size
+                                             4,  // D - depth dimension
+                                             4,  // H - height dimension
+                                             4,  // W - width dimension
+                                             32, // C - channel dimension
+                                             2,  // Z - pooling window depth
+                                             2,  // Y - pooling window height
+                                             2,  // X - pooling window width
+                                             2,  // Sz - window stride depth
+                                             2,  // Sy - window stride height
+                                             2,  // Sx - window stride width
+                                             1,  // Dz - window dilation depth
+                                             1,  // Dy - window dilation height
+                                             1,  // Dx - window dilation width
+                                             0,  // LeftPz - left padding depth
+                                             0,  // LeftPy - left padding height
+                                             0,  // LeftPx - left padding width
+                                             0,  // RightPz - right padding depth
+                                             0,  // RightPy - right padding height
+                                             0,  // RightPx - right padding width
+                                             "2x2x2 pooling"};
+    bool pass = this->RunPool3D(config);
+    EXPECT_TRUE(pass);
+}
+
+// Overlapping 3x3x3 window with symmetric padding of 1.
+TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
+{
+    typename TestFixture::Config3D config = {2,   // N - batch size
+                                             16,  // D - depth dimension
+                                             16,  // H - height dimension
+                                             16,  // W - width dimension
+                                             128, // C - channel dimension
+                                             3,   // Z - pooling window depth
+                                             3,   // Y - pooling window height
+                                             3,   // X - pooling window width
+                                             2,   // Sz - window stride depth
+                                             2,   // Sy - window stride height
+                                             2,   // Sx - window stride width
+                                             1,   // Dz - window dilation depth
+                                             1,   // Dy - window dilation height
+                                             1,   // Dx - window dilation width
+                                             1,   // LeftPz - left padding depth
+                                             1,   // LeftPy - left padding height
+                                             1,   // LeftPx - left padding width
+                                             1,   // RightPz - right padding depth
+                                             1,   // RightPy - right padding height
+                                             1,   // RightPx - right padding width
+                                             "3x3x3 pooling"};
+    bool pass = this->RunPool3D(config);
+    EXPECT_TRUE(pass);
+}