// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/ops/pool.hpp" #include "ck_tile/host/reference/reference_pool.hpp" #include "ck_tile/host/kernel_launch.hpp" template class TestCkTilePooling : public ::testing::Test { protected: using InDataType = std::tuple_element_t<0, Tuple>; using OutDataType = std::tuple_element_t<1, Tuple>; using ComputeDataType = std::tuple_element_t<2, Tuple>; using ReduceOpType = std::tuple_element_t<3, Tuple>; using BlockWarps_ = std::tuple_element_t<4, Tuple>; using BlockTile_ = std::tuple_element_t<5, Tuple>; using WarpTile_ = std::tuple_element_t<6, Tuple>; using ThreadTile_ = std::tuple_element_t<7, Tuple>; using TestPoolShape = ck_tile::PoolShape; // 3D pooling configuration struct Config3D { ck_tile::index_t N, D, H, W, C; ck_tile::index_t Z, Y, X; ck_tile::index_t Sz, Sy, Sx; ck_tile::index_t Dz, Dy, Dx; ck_tile::index_t LeftPz, LeftPy, LeftPx; ck_tile::index_t RightPz, RightPy, RightPx; std::string name; }; bool RunPool3D(const Config3D& config) { std::cout << "Testing 3D: " << config.name << " ... "; const ck_tile::index_t Zs = (config.Z - 1) * config.Dz + 1; const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1; const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1; const ck_tile::index_t Do = (config.D + config.LeftPz + config.RightPz - Zs) / config.Sz + 1; const ck_tile::index_t Ho = (config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1; const ck_tile::index_t Wo = (config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1; const auto input_shape = ck_tile::make_tuple(config.N, config.D, config.H, config.W, config.C); const auto output_shape = ck_tile::make_tuple(config.N, Do, Ho, Wo, config.C); const auto input_strides = ck_tile::make_tuple(config.D * config.H * config.W * config.C, config.H * config.W * config.C, config.W * config.C, config.C, 1); const auto output_strides = ck_tile::make_tuple( Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1); const auto window_spatial_lengths = ck_tile::make_tuple(config.Z, config.Y, config.X); const auto window_strides = ck_tile::make_tuple(config.Sz, config.Sy, config.Sx); const auto window_dilations = ck_tile::make_tuple(config.Dz, config.Dy, config.Dx); const auto input_left_pads = ck_tile::make_tuple(config.LeftPz, config.LeftPy, config.LeftPx); const auto input_right_pads = ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx); ck_tile::HostTensor h_in({config.N, config.D, config.H, config.W, config.C}, {config.D * config.H * config.W * config.C, config.H * config.W * config.C, config.W * config.C, config.C, 1}); ck_tile::HostTensor h_out( {config.N, Do, Ho, Wo, config.C}, {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1}); ck_tile::HostTensor h_out_ref( {config.N, Do, Ho, Wo, config.C}, {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1}); ck_tile::FillUniformDistribution{-5.f, 5.f}(h_in); h_out.SetZero(); h_out_ref.SetZero(); ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes()); ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes()); d_in_mem.ToDevice(h_in.data()); d_out_mem.ToDevice(h_out.data()); using Problem = ck_tile::PoolProblem; using Kernel = ck_tile::PoolKernel; constexpr ck_tile::index_t kBlockPerCu = 1; const ck_tile::index_t kBlockSize = Kernel::BlockSize(); auto host_args = ck_tile::PoolHostArgs{ static_cast(d_in_mem.GetDeviceBuffer()), static_cast(d_out_mem.GetDeviceBuffer()), input_shape, output_shape, input_strides, output_strides, window_spatial_lengths, window_strides, window_dilations, input_left_pads, input_right_pads}; auto kernel_args = Kernel::MakeKernelArgs(host_args); const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args); if(!Kernel::IsSupportedArgument(kernel_args)) { return true; } // Run kernel ck_tile::launch_kernel( ck_tile::stream_config{nullptr, false, 0}, ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, kernel_args)); // Run reference implementation ck_tile::reference_pool3d( h_in, h_out_ref, kernel_args, ReduceOpType{}); d_out_mem.FromDevice(h_out.data()); // Validate results bool pass = ck_tile::check_err(h_out, h_out_ref); std::cout << (pass ? "PASS" : "FAIL") << std::endl; return pass; } }; using Shape1_BlockWarps = ck_tile::sequence<4, 1>; using Shape1_BlockTile = ck_tile::sequence<128, 128>; using Shape1_WarpTile = ck_tile::sequence<32, 128>; using Shape1_ThreadTile = ck_tile::sequence<8, 8>; // Cross-warp configuration using Shape2_BlockWarps = ck_tile::sequence<2, 2>; using Shape2_BlockTile = ck_tile::sequence<2, 1024>; using Shape2_WarpTile = ck_tile::sequence<1, 512>; using Shape2_ThreadTile = ck_tile::sequence<1, 8>; // Test configurations for different data types and operations using TestConfig_F32_Max = std::tuple; using TestConfig_F16_Max = std::tuple; using TestConfig_F32_CrossWarp = std::tuple; using TestTypes = ::testing::Types; TYPED_TEST_SUITE(TestCkTilePooling, TestTypes); TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2) { typename TestFixture::Config3D config = {1, // N - batch size 4, // D - depth dimension 4, // H - height dimension 4, // W - width dimension 32, // C - channel dimension 2, // Z - pooling window depth 2, // Y - pooling window height 2, // X - pooling window width 2, // Sz - window stride depth 2, // Sy - window stride height 2, // Sx - window stride width 1, // Dz - window dilation depth 1, // Dy - window dilation height 1, // Dx - window dilation width 0, // LeftPz - left padding depth 0, // LeftPy - left padding height 0, // LeftPx - left padding width 0, // RightPz - right padding depth 0, // RightPy - right padding height 0, // RightPx - right padding width "2x2x2 pooling"}; bool pass = this->RunPool3D(config); EXPECT_TRUE(pass); } TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3) { typename TestFixture::Config3D config = {2, // N - batch size 16, // D - depth dimension 16, // H - height dimension 16, // W - width dimension 128, // C - channel dimension 3, // Z - pooling window depth 3, // Y - pooling window height 3, // X - pooling window width 2, // Sz - window stride depth 2, // Sy - window stride height 2, // Sx - window stride width 1, // Dz - window dilation depth 1, // Dy - window dilation height 1, // Dx - window dilation width 1, // LeftPz - left padding depth 1, // LeftPy - left padding height 1, // LeftPx - left padding width 1, // RightPz - right padding depth 1, // RightPy - right padding height 1, // RightPx - right padding width "3x3x3 pooling"}; bool pass = this->RunPool3D(config); EXPECT_TRUE(pass); }