[CK_TILE] Pooling FWD (Lwpck 3683) (#2956)

* Pooling 2D/3D with reference

* Tests & cleanup

- added test for pooling
- cleanup
- removed 2d example

* Comment resolution

- README added
- example target name rectified
- appropriate arg description and comments added

* clang-format

* appropriate blocksize calc

* modifications for future indexing addition

- instead of transforming views we now transform the descriptors, so
that the same descriptor can be re-used for index tensor in the future

* some basic fixes

* comment resolutions

* comment resolutions

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Yashvardhan Agarwal
2025-10-09 17:13:26 +03:00
committed by GitHub
parent 9d4bfe3932
commit 7b6451b68e
14 changed files with 1317 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
* Added support for f32 to FMHA (fwd/bwd).
* Added tensor-wise quantization for CK_TILE GEMM.
* Added pooling kernel in CK_TILE.
### Optimized

View File

@@ -0,0 +1,8 @@
# Target name for the 3D pooling example executable.
set(EXAMPLE_POOL_3D "tile_example_pool3d")
message(DEBUG "adding example ${EXAMPLE_POOL_3D}")
# EXCLUDE_FROM_ALL: the example is only built on demand (e.g. `make tile_example_pool3d`).
add_executable(${EXAMPLE_POOL_3D} EXCLUDE_FROM_ALL pool3d.cpp)
target_include_directories(${EXAMPLE_POOL_3D} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# EXAMPLE_POOL_COMPILE_OPTIONS is presumably defined by the parent scope — confirm there.
target_compile_options(${EXAMPLE_POOL_3D} PRIVATE ${EXAMPLE_POOL_COMPILE_OPTIONS})

View File

@@ -0,0 +1,42 @@
# Pooling Operator
This folder contains an example of the pooling operator using the ck_tile tile-programming implementation. Currently the pooling kernel supports only 2D and 3D pooling.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# The 3D pooling example
make tile_example_pool3d -j`nproc`
```
This will result in an executable `build/bin/tile_example_pool3d`
## example
```
args:
-N batch size (default:2)
-D depth dimension (default:30)
-H height dimension (default:30)
-W width dimension (default:30)
-C channel dimension (default:32)
-Z pooling window depth (default:2)
-Y pooling window height (default:2)
-X pooling window width (default:2)
-Sz window stride depth (default:2)
-Sy window stride height (default:2)
-Sx window stride width (default:2)
-Dz window dilation depth (default:1)
-Dy window dilation height (default:1)
-Dx window dilation width (default:1)
-LeftPz left padding depth (default:1)
-LeftPy left padding height (default:1)
-LeftPx left padding width (default:1)
-RightPz right padding depth (default:1)
-RightPy right padding height (default:1)
-RightPx right padding width (default:1)
-v 0: No validation, 1: CPU validation (default:1)
-warmup number of iterations before benchmark (default:0)
-repeat number of iterations to benchmark (default:1)
```

View File

@@ -0,0 +1,188 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pool.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include <cstring>
#include <iostream>
// Parse command-line arguments for 3D pooling example
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("N", "2", "N dimension")
.insert("H", "30", "H dimension")
.insert("W", "30", "W dimension")
.insert("C", "32", "C dimension")
.insert("D", "30", "D dimension")
.insert("Z", "2", "Z dimension")
.insert("Y", "2", "Y dimension")
.insert("X", "2", "X dimension")
.insert("Sz", "2", "window stride d")
.insert("Sy", "2", "window stride h")
.insert("Sx", "2", "window stride w")
.insert("Dz", "1", "window dilation d")
.insert("Dy", "1", "window dilation h")
.insert("Dx", "1", "window dilation w")
.insert("LeftPz", "1", "left padding d")
.insert("LeftPy", "1", "left padding h")
.insert("LeftPx", "1", "left padding w")
.insert("RightPz", "1", "right padding d")
.insert("RightPy", "1", "right padding h")
.insert("RightPx", "1", "right padding w")
.insert("v", "1", "cpu validation or not")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
/// @brief Run one 3D max-pooling problem on the GPU and optionally validate
///        against the CPU reference.
///
/// @tparam InDataType      element type of the input tensor
/// @tparam OutDataType     element type of the output tensor
/// @tparam ComputeDataType accumulator type used for the reduction
/// @param arg_parser parsed command-line arguments (see create_args)
/// @return true when validation passes (or is disabled and setup succeeds)
template <typename InDataType, typename OutDataType, typename ComputeDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    // Problem extents (input layout is NDHWC).
    const ck_tile::index_t N = arg_parser.get_int("N");
    const ck_tile::index_t H = arg_parser.get_int("H");
    const ck_tile::index_t W = arg_parser.get_int("W");
    const ck_tile::index_t C = arg_parser.get_int("C");
    const ck_tile::index_t D = arg_parser.get_int("D");
    // Pooling window lengths, strides, dilations, and padding.
    const ck_tile::index_t Z       = arg_parser.get_int("Z");
    const ck_tile::index_t Y       = arg_parser.get_int("Y");
    const ck_tile::index_t X       = arg_parser.get_int("X");
    const ck_tile::index_t Sz      = arg_parser.get_int("Sz");
    const ck_tile::index_t Sy      = arg_parser.get_int("Sy");
    const ck_tile::index_t Sx      = arg_parser.get_int("Sx");
    const ck_tile::index_t Dz      = arg_parser.get_int("Dz");
    const ck_tile::index_t Dy      = arg_parser.get_int("Dy");
    const ck_tile::index_t Dx      = arg_parser.get_int("Dx");
    const ck_tile::index_t LeftPz  = arg_parser.get_int("LeftPz");
    const ck_tile::index_t LeftPy  = arg_parser.get_int("LeftPy");
    const ck_tile::index_t LeftPx  = arg_parser.get_int("LeftPx");
    const ck_tile::index_t RightPz = arg_parser.get_int("RightPz");
    const ck_tile::index_t RightPy = arg_parser.get_int("RightPy");
    const ck_tile::index_t RightPx = arg_parser.get_int("RightPx");
    // Effective window extent once dilation is applied: (len - 1) * dilation + 1.
    const ck_tile::index_t Zs = (Z - 1) * Dz + 1;
    const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
    const ck_tile::index_t Xs = (X - 1) * Dx + 1;
    // Output spatial extents (standard pooling output-shape formula).
    const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
    const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
    const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
    // Reject configurations where the dilated window does not fit into the
    // padded input: the formula above then yields a non-positive extent and the
    // host tensors below would otherwise be constructed with invalid shapes.
    if(Do <= 0 || Ho <= 0 || Wo <= 0)
    {
        std::cerr << "Invalid configuration: dilated pooling window does not fit into the "
                     "padded input (output extents must be positive)."
                  << std::endl;
        return false;
    }
    printf("Input parameters:\n");
    printf("N: %d, D: %d, H: %d, W: %d, C: %d\n", N, D, H, W, C);
    printf("Window Z: %d, Y: %d, X: %d, Stride Z: %d, Y: %d, X: %d\n", Z, Y, X, Sz, Sy, Sx);
    printf("Output Do: %d, Ho: %d, Wo: %d\n", Do, Ho, Wo);
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");
    // Shapes / strides / parameters (NDHWC, channel-fastest packed layout).
    const auto input_shape            = ck_tile::make_tuple(N, D, H, W, C);
    const auto output_shape           = ck_tile::make_tuple(N, Do, Ho, Wo, C);
    const auto input_strides          = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1);
    const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1);
    const auto window_spatial_lengths = ck_tile::make_tuple(Z, Y, X);
    const auto window_strides         = ck_tile::make_tuple(Sz, Sy, Sx);
    const auto window_dilations       = ck_tile::make_tuple(Dz, Dy, Dx);
    const auto input_left_pads        = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
    const auto input_right_pads       = ck_tile::make_tuple(RightPz, RightPy, RightPx);
    // Host tensors: input, device result, and CPU reference result.
    ck_tile::HostTensor<InDataType> in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1});
    ck_tile::HostTensor<OutDataType> out({N, Do, Ho, Wo, C},
                                         {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
    ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
                                             {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
    ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
    ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
    ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
    in_buf.ToDevice(in.data());
    // Kernel configuration: max pooling over a 128x128 block tile.
    using ReduceOp   = ck_tile::ReduceOp::Max;
    using BlockWarps = ck_tile::sequence<4, 1>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using WarpTile   = ck_tile::sequence<32, 128>;
    using ThreadTile = ck_tile::sequence<8, 8>;
    using Shape      = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    using Problem    = ck_tile::PoolProblem<InDataType,
                                            OutDataType,
                                            ComputeDataType,
                                            OutDataType,
                                            ReduceOp,
                                            false,
                                            false,
                                            Shape>;
    using Kernel     = ck_tile::PoolKernel<Problem>;
    constexpr ck_tile::index_t kBlockPerCu = 1;
    const ck_tile::index_t kBlockSize      = Kernel::BlockSize();
    auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
        static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
        static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_spatial_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};
    auto kernel_args                 = Kernel::MakeKernelArgs(host_args);
    const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
    std::cout << "grid size " << kGridSize << std::endl;
    // Validate kernel can handle the given configuration.
    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        throw std::runtime_error("ERROR: Kernel arguments are not supported! \n");
    }
    float ave_time = launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
    // Effective bandwidth: bytes read plus bytes written over elapsed time.
    std::size_t num_btype =
        sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C;
    float gb_per_sec = num_btype / 1.E6 / ave_time;
    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
    bool pass = true;
    if(do_validation)
    {
        // CPU reference is computed from the same kernel args to guarantee a
        // consistent problem description.
        ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
            in, out_ref, kernel_args, ReduceOp{});
        out_buf.FromDevice(out.mData.data());
        pass = ck_tile::check_err(out, out_ref);
        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }
    return pass;
}
// Entry point: parse arguments, then run fp16 3D max pooling with fp32 accumulation.
int main(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }
    const bool pass = run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser);
    return pass ? 0 : -2;
}

View File

@@ -23,6 +23,7 @@ add_subdirectory(20_grouped_convolution)
add_subdirectory(21_elementwise)
add_subdirectory(22_gemm_multi_abd)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_pooling)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)
add_subdirectory(40_streamk_gemm)

View File

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
/// @brief CPU reference for 2D pooling on an NHWC tensor.
///
/// Each output element reduces one (possibly dilated, strided, padded) Y x X
/// window of the input with @p reduce_op. Window taps that fall outside the
/// real input (i.e. into left/right padding) are skipped, which is equivalent
/// to reducing them with the identity value.
template <typename InDataType,
          typename ComputeDataType,
          typename OutDataType,
          typename ReduceOp,
          typename TensorShape,
          typename WindowShape>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
                                   HostTensor<OutDataType>& output,
                                   PoolKernelArgs<TensorShape, WindowShape> kargs,
                                   ReduceOp reduce_op)
{
    // Input extents (NHWC) and output spatial extents.
    const index_t N  = kargs.input_shape.at(number<0>{});
    const index_t H  = kargs.input_shape.at(number<1>{});
    const index_t W  = kargs.input_shape.at(number<2>{});
    const index_t C  = kargs.input_shape.at(number<3>{});
    const index_t Ho = kargs.output_shape.at(number<1>{});
    const index_t Wo = kargs.output_shape.at(number<2>{});
    // Window geometry: lengths, strides, dilations, and left padding.
    const index_t Y      = kargs.window_lengths.at(number<0>{});
    const index_t X      = kargs.window_lengths.at(number<1>{});
    const index_t Sy     = kargs.window_strides.at(number<0>{});
    const index_t Sx     = kargs.window_strides.at(number<1>{});
    const index_t Dy     = kargs.window_dilations.at(number<0>{});
    const index_t Dx     = kargs.window_dilations.at(number<1>{});
    const index_t LeftPy = kargs.input_left_pads.at(number<0>{});
    const index_t LeftPx = kargs.input_left_pads.at(number<1>{});
    // Right padding needs no explicit handling: out-of-range taps are skipped.
    auto pool_one = [&](auto n, auto oh, auto ow, auto c) {
        ComputeDataType acc = reduce_op.template GetIdentityValue<ComputeDataType>();
        for(index_t ky = 0; ky < Y; ++ky)
        {
            // Input row for this tap, accounting for stride, dilation, padding.
            const index_t ih = oh * Sy + ky * Dy - LeftPy;
            if(ih < 0 || ih >= H)
                continue; // tap lands in padding — contributes the identity
            for(index_t kx = 0; kx < X; ++kx)
            {
                const index_t iw = ow * Sx + kx * Dx - LeftPx;
                if(iw < 0 || iw >= W)
                    continue;
                acc = reduce_op(acc, type_convert<ComputeDataType>(input(n, ih, iw, c)));
            }
        }
        output(n, oh, ow, c) = type_convert<OutDataType>(acc);
    };
    // Parallelize over all output coordinates.
    make_ParallelTensorFunctor(pool_one, N, Ho, Wo, C)(std::thread::hardware_concurrency());
}
/// @brief CPU reference for 3D pooling on an NDHWC tensor.
///
/// Each output element reduces one (possibly dilated, strided, padded)
/// Z x Y x X window of the input with @p reduce_op. Window taps that fall
/// outside the real input (i.e. into left/right padding) are skipped, which
/// is equivalent to reducing them with the identity value.
template <typename InDataType,
          typename ComputeDataType,
          typename OutDataType,
          typename ReduceOp,
          typename TensorShape,
          typename WindowShape>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
                                   HostTensor<OutDataType>& output,
                                   PoolKernelArgs<TensorShape, WindowShape> kargs,
                                   ReduceOp reduce_op)
{
    // Input extents (NDHWC) and output spatial extents.
    const index_t N  = kargs.input_shape.at(number<0>{});
    const index_t D  = kargs.input_shape.at(number<1>{});
    const index_t H  = kargs.input_shape.at(number<2>{});
    const index_t W  = kargs.input_shape.at(number<3>{});
    const index_t C  = kargs.input_shape.at(number<4>{});
    const index_t Do = kargs.output_shape.at(number<1>{});
    const index_t Ho = kargs.output_shape.at(number<2>{});
    const index_t Wo = kargs.output_shape.at(number<3>{});
    // Window geometry: lengths, strides, dilations, and left padding.
    const index_t Z      = kargs.window_lengths.at(number<0>{});
    const index_t Y      = kargs.window_lengths.at(number<1>{});
    const index_t X      = kargs.window_lengths.at(number<2>{});
    const index_t Sz     = kargs.window_strides.at(number<0>{});
    const index_t Sy     = kargs.window_strides.at(number<1>{});
    const index_t Sx     = kargs.window_strides.at(number<2>{});
    const index_t Dz     = kargs.window_dilations.at(number<0>{});
    const index_t Dy     = kargs.window_dilations.at(number<1>{});
    const index_t Dx     = kargs.window_dilations.at(number<2>{});
    const index_t LeftPz = kargs.input_left_pads.at(number<0>{});
    const index_t LeftPy = kargs.input_left_pads.at(number<1>{});
    const index_t LeftPx = kargs.input_left_pads.at(number<2>{});
    // Right padding needs no explicit handling: out-of-range taps are skipped.
    auto pool_one = [&](auto n, auto od, auto oh, auto ow, auto c) {
        ComputeDataType acc = reduce_op.template GetIdentityValue<ComputeDataType>();
        for(index_t kz = 0; kz < Z; ++kz)
        {
            // Input depth for this tap, accounting for stride, dilation, padding.
            const index_t id = od * Sz + kz * Dz - LeftPz;
            if(id < 0 || id >= D)
                continue; // tap lands in padding — contributes the identity
            for(index_t ky = 0; ky < Y; ++ky)
            {
                const index_t ih = oh * Sy + ky * Dy - LeftPy;
                if(ih < 0 || ih >= H)
                    continue;
                for(index_t kx = 0; kx < X; ++kx)
                {
                    const index_t iw = ow * Sx + kx * Dx - LeftPx;
                    if(iw < 0 || iw >= W)
                        continue;
                    acc = reduce_op(acc, type_convert<ComputeDataType>(input(n, id, ih, iw, c)));
                }
            }
        }
        output(n, od, oh, ow, c) = type_convert<OutDataType>(acc);
    };
    // Parallelize over all output coordinates.
    make_ParallelTensorFunctor(pool_one, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,496 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
#include "ck_tile/ops/common.hpp"
#include <type_traits>
namespace ck_tile {
/// @brief Host arguments for pooling operations.
///
/// Bundles type-erased input/output pointers with the shape, stride, window,
/// dilation, and padding tuples that describe one pooling problem. Converted
/// into device-side arguments by PoolKernel::MakeKernelArgs.
///
/// @tparam TensorShape tuple type describing tensor extents/strides
/// @tparam WindowShape tuple type describing per-spatial-dim window parameters
template <typename TensorShape, typename WindowShape>
struct PoolHostArgs
{
    /// @brief Construct host arguments; every parameter is stored verbatim.
    /// @param input_ptr_        pointer to input tensor data (the example passes a device buffer)
    /// @param output_ptr_       pointer to output tensor data (the example passes a device buffer)
    /// @param input_shape_      input tensor extents
    /// @param output_shape_     output tensor extents
    /// @param input_strides_    input tensor strides (element units)
    /// @param output_strides_   output tensor strides (element units)
    /// @param window_lengths_   pooling window extents per spatial dimension
    /// @param window_strides_   pooling window strides per spatial dimension
    /// @param window_dilations_ pooling window dilations per spatial dimension
    /// @param input_left_pads_  left padding per spatial dimension
    /// @param input_right_pads_ right padding per spatial dimension
    CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
                              void* output_ptr_,
                              TensorShape input_shape_,
                              TensorShape output_shape_,
                              TensorShape input_strides_,
                              TensorShape output_strides_,
                              WindowShape window_lengths_,
                              WindowShape window_strides_,
                              WindowShape window_dilations_,
                              WindowShape input_left_pads_,
                              WindowShape input_right_pads_)
        : input_ptr(input_ptr_),
          output_ptr(output_ptr_),
          input_shape(input_shape_),
          output_shape(output_shape_),
          input_strides(input_strides_),
          output_strides(output_strides_),
          window_lengths(window_lengths_),
          window_strides(window_strides_),
          window_dilations(window_dilations_),
          input_left_pads(input_left_pads_),
          input_right_pads(input_right_pads_)
    {
    }
    const void* input_ptr;        // type-erased input data pointer
    void* output_ptr;             // type-erased output data pointer
    TensorShape input_shape;      // input tensor extents
    TensorShape output_shape;     // output tensor extents
    TensorShape input_strides;    // input strides, element units
    TensorShape output_strides;   // output strides, element units
    WindowShape window_lengths;   // window extent per spatial dim
    WindowShape window_strides;   // window stride per spatial dim
    WindowShape window_dilations; // window dilation per spatial dim
    WindowShape input_left_pads;  // left padding per spatial dim
    WindowShape input_right_pads; // right padding per spatial dim
};
/// @brief Kernel arguments for pooling operations.
///
/// Plain aggregate mirror of PoolHostArgs that is passed by value to the
/// device kernel; member meanings are identical to the host-side struct.
template <typename TensorShape, typename WindowShape>
struct PoolKernelArgs
{
    const void* input_ptr;        // type-erased input data pointer
    void* output_ptr;             // type-erased output data pointer
    TensorShape input_shape;      // input tensor extents
    TensorShape output_shape;     // output tensor extents
    TensorShape input_strides;    // input strides, element units
    TensorShape output_strides;   // output strides, element units
    WindowShape window_lengths;   // window extent per spatial dim
    WindowShape window_strides;   // window stride per spatial dim
    WindowShape window_dilations; // window dilation per spatial dim
    WindowShape input_left_pads;  // left padding per spatial dim
    WindowShape input_right_pads; // right padding per spatial dim
};
template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
struct PoolKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
/// @brief Launch block size: halved on wave32 targets, nominal otherwise.
CK_TILE_HOST static constexpr auto BlockSize()
{
    if(is_wave32())
    {
        return kBlockSize / 2;
    }
    return kBlockSize;
}
/// @brief Build padded 2D (M x K) input/output tensor views for 2D pooling.
///
/// The pooling problem is flattened into a 2D reduction: rows M = N*Ho*Wo*C
/// output positions, columns K = Y*X window taps. The reduce identity is
/// installed as the buffer views' invalid-element value (presumably returned
/// for out-of-range accesses — confirm against make_buffer_view), so padded
/// reads do not perturb the reduction.
template <typename TensorShape, typename WindowShape>
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs<TensorShape, WindowShape> kargs)
{
    using S = typename Problem::BlockShape;
    // Compile-time validation for 2D pooling
    static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
    static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");
    // Extract dimension values
    const index_t N               = kargs.input_shape.at(number<0>{});
    const index_t H               = kargs.input_shape.at(number<1>{});
    const index_t W               = kargs.input_shape.at(number<2>{});
    const index_t C               = kargs.input_shape.at(number<3>{});
    const index_t No              = kargs.output_shape.at(number<0>{});
    const index_t Ho              = kargs.output_shape.at(number<1>{});
    const index_t Wo              = kargs.output_shape.at(number<2>{});
    const index_t Co              = kargs.output_shape.at(number<3>{});
    const index_t Y               = kargs.window_lengths.at(number<0>{});
    const index_t X               = kargs.window_lengths.at(number<1>{});
    const index_t WindowStrideH   = kargs.window_strides.at(number<0>{});
    const index_t WindowStrideW   = kargs.window_strides.at(number<1>{});
    const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
    const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});
    const index_t InLeftPadH      = kargs.input_left_pads.at(number<0>{});
    const index_t InLeftPadW      = kargs.input_left_pads.at(number<1>{});
    const index_t InRightPadH     = kargs.input_right_pads.at(number<0>{});
    const index_t InRightPadW     = kargs.input_right_pads.at(number<1>{});
    // M = flattened output positions, K = flattened window taps; both are
    // right-padded up to multiples of the block tile so every block is full.
    const index_t MRaw = N * Ho * Wo * C;
    const index_t KRaw = Y * X;
    const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
    const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
    auto reduce_op     = typename Problem::ReduceOp{};
    // Create input descriptor with all transformations
    auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
    // Apply spatial padding to input descriptor
    const auto padded_in_desc = transform_tensor_descriptor(
        in_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(H, InLeftPadH, InRightPadH),
                   make_pad_transform(W, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
    // Create sliding windows by embedding pooling windows into the descriptor:
    // each padded spatial dim (H, W) splits into (tap, output) pairs with
    // strides (dilation, window stride).
    const auto embed_in_desc = transform_tensor_descriptor(
        padded_in_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
            make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
        make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
    // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
    const auto merged_embed_in_desc =
        transform_tensor_descriptor(embed_in_desc,
                                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
                                               make_merge_transform(make_tuple(Y, X))),
                                    make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
                                    make_tuple(sequence<0>{}, sequence<1>{}));
    // Right-pad M and K so the tile windows never run off the descriptor.
    const auto in_desc_padded = transform_tensor_descriptor(
        merged_embed_in_desc,
        make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // Create output descriptor with transformations: flatten to 1D M, then pad.
    auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
    const auto merged_out_desc =
        transform_tensor_descriptor(out_desc,
                                    make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))),
                                    make_tuple(sequence<0, 1, 2, 3>{}),
                                    make_tuple(sequence<0>{}));
    const auto out_desc_padded =
        transform_tensor_descriptor(merged_out_desc,
                                    make_tuple(make_right_pad_transform(MRaw, MPad)),
                                    make_tuple(sequence<0>{}),
                                    make_tuple(sequence<0>{}));
    // Now create buffer views and tensor views with the fully transformed
    // descriptors; identity values are used as the invalid-element fill.
    const InDataType in_identity =
        type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
    const OutDataType out_identity =
        type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
    auto in_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<const InDataType*>(kargs.input_ptr),
        in_desc.get_element_space_size(),
        in_identity);
    const auto in_tensor_padded =
        tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
                                                                        in_desc_padded};
    auto out_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<OutDataType*>(kargs.output_ptr),
        out_desc.get_element_space_size(),
        out_identity);
    const auto out_tensor_padded =
        tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
                                                                          out_desc_padded};
    return make_tuple(in_tensor_padded, out_tensor_padded);
}
/// @brief Build padded 2D (M x K) input/output tensor views for 3D pooling.
///
/// The pooling problem is flattened into a 2D reduction: rows M = N*Do*Ho*Wo*C
/// output positions, columns K = Z*Y*X window taps. The reduce identity is
/// installed as the buffer views' invalid-element value (presumably returned
/// for out-of-range accesses — confirm against make_buffer_view), so padded
/// reads do not perturb the reduction.
template <typename TensorShape, typename WindowShape>
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs<TensorShape, WindowShape> kargs)
{
    using S = typename Problem::BlockShape;
    // Compile-time validation for 3D pooling
    static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
    static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");
    // Extract dimension values
    const index_t N               = kargs.input_shape.at(number<0>{});
    const index_t D               = kargs.input_shape.at(number<1>{});
    const index_t H               = kargs.input_shape.at(number<2>{});
    const index_t W               = kargs.input_shape.at(number<3>{});
    const index_t C               = kargs.input_shape.at(number<4>{});
    const index_t No              = kargs.output_shape.at(number<0>{});
    const index_t Do              = kargs.output_shape.at(number<1>{});
    const index_t Ho              = kargs.output_shape.at(number<2>{});
    const index_t Wo              = kargs.output_shape.at(number<3>{});
    const index_t Co              = kargs.output_shape.at(number<4>{});
    const index_t Z               = kargs.window_lengths.at(number<0>{});
    const index_t Y               = kargs.window_lengths.at(number<1>{});
    const index_t X               = kargs.window_lengths.at(number<2>{});
    const index_t WindowStrideD   = kargs.window_strides.at(number<0>{});
    const index_t WindowStrideH   = kargs.window_strides.at(number<1>{});
    const index_t WindowStrideW   = kargs.window_strides.at(number<2>{});
    const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
    const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
    const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});
    const index_t InLeftPadD      = kargs.input_left_pads.at(number<0>{});
    const index_t InLeftPadH      = kargs.input_left_pads.at(number<1>{});
    const index_t InLeftPadW      = kargs.input_left_pads.at(number<2>{});
    const index_t InRightPadD     = kargs.input_right_pads.at(number<0>{});
    const index_t InRightPadH     = kargs.input_right_pads.at(number<1>{});
    const index_t InRightPadW     = kargs.input_right_pads.at(number<2>{});
    // M = flattened output positions, K = flattened window taps; both are
    // right-padded up to multiples of the block tile so every block is full.
    const index_t MRaw = N * Do * Ho * Wo * C;
    const index_t KRaw = Z * Y * X;
    const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
    const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
    auto reduce_op     = typename Problem::ReduceOp{};
    // Create input descriptor with all transformations
    auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
    // Apply spatial padding to input descriptor (all 3D dimensions)
    const auto padded_in_desc = transform_tensor_descriptor(
        in_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(D, InLeftPadD, InRightPadD),
                   make_pad_transform(H, InLeftPadH, InRightPadH),
                   make_pad_transform(W, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
    // Create 3D sliding windows by embedding pooling windows into the
    // descriptor: each padded spatial dim (D, H, W) splits into (tap, output)
    // pairs with strides (dilation, window stride).
    const auto embed_in_desc = transform_tensor_descriptor(
        padded_in_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
            make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
        make_tuple(sequence<0>{},
                   sequence<1, 2>{},
                   sequence<3, 4>{},
                   sequence<5, 6>{},
                   sequence<7>{}));
    // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
    const auto merged_embed_in_desc = transform_tensor_descriptor(
        embed_in_desc,
        make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
                   make_merge_transform(make_tuple(Z, Y, X))),
        make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // Right-pad M and K so the tile windows never run off the descriptor.
    const auto in_desc_padded = transform_tensor_descriptor(
        merged_embed_in_desc,
        make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));
    // Create output descriptor with transformations: flatten to 1D M, then pad.
    auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
    const auto merged_out_desc = transform_tensor_descriptor(
        out_desc,
        make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
        make_tuple(sequence<0, 1, 2, 3, 4>{}),
        make_tuple(sequence<0>{}));
    const auto out_desc_padded =
        transform_tensor_descriptor(merged_out_desc,
                                    make_tuple(make_right_pad_transform(MRaw, MPad)),
                                    make_tuple(sequence<0>{}),
                                    make_tuple(sequence<0>{}));
    // Now create buffer views and tensor views with the fully transformed
    // descriptors; identity values are used as the invalid-element fill.
    const InDataType in_identity =
        type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
    const OutDataType out_identity =
        type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
    auto in_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<const InDataType*>(kargs.input_ptr),
        in_desc.get_element_space_size(),
        in_identity);
    const auto in_tensor_padded =
        tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
                                                                        in_desc_padded};
    auto out_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<OutDataType*>(kargs.output_ptr),
        out_desc.get_element_space_size(),
        out_identity);
    const auto out_tensor_padded =
        tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
                                                                          out_desc_padded};
    return make_tuple(in_tensor_padded, out_tensor_padded);
}
public:
/// @brief Device entry point: each workgroup reduces one Block_M slice of
///        flattened output positions over all K tiles of window taps.
///
/// @param kargs pooling kernel arguments (shapes, strides, window geometry)
template <typename TensorShape, typename WindowShape>
CK_TILE_DEVICE void operator()(PoolKernelArgs<TensorShape, WindowShape> kargs) const
{
    using S = typename Problem::BlockShape;
    // Compile-time validation for supported window dimensions
    static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
                  "Only 2D and 3D pooling operations are supported");
    // First M (output-position) row handled by this workgroup.
    const auto iM = get_block_id() * S::Block_M;
    // Get tensors based on dimensionality
    auto [in_tensor_padded, out_tensor_padded] = [&]() {
        if constexpr(WindowShape::size() == 2)
            return MakeTensorView2D(kargs);
        else if constexpr(WindowShape::size() == 3)
            return MakeTensorView3D(kargs);
        else
            static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
                          "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
    }();
    auto reduce_op = typename Problem::ReduceOp{};
    // Input window: Block_M x Block_N tile starting at row iM, column 0.
    auto x_window  = make_tile_window(in_tensor_padded,
                                      make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                      {iM, 0},
                                      Policy::template MakeXBlockTileDistribution<Problem>());
    // Output window: one Block_M row of reduced results.
    auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
    // Scratch for the cross-warp reduction stage.
    __shared__ char smem[Policy::template GetSmemSize<Problem>()];
    // Padded K extent; readfirstlane keeps the tile count scalar (uniform
    // across the wave) so the loop below is uniformly bounded.
    const auto reduce_len =
        in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
    index_t num_k_tiles =
        __builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
    auto block_reduce2d            = Policy::template GetBlockReduce2d<Problem>();
    auto block_reduce2d_sync       = Policy::template GetBlockReduce2dSync<Problem>();
    auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
    using XTensorTile = decltype(load_tile(x_window));
    // Accumulator tile, seeded with the reduction identity.
    auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
    set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
    // Sweep the K dimension one Block_N tile at a time, folding each loaded
    // tile into the accumulator.
    for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
    {
        const auto x_tile = load_tile(x_window);
        block_reduce2d(x_tile, y_tile, reduce_op);
        move_tile_window(x_window, {0, S::Block_N});
    }
    // Finish the reduction across lanes and warps (cross-warp stage goes
    // through shared memory), then write the casted result out.
    block_reduce2d_sync(y_tile, reduce_op);
    block_reduce2d_cross_warp(y_tile, smem, reduce_op);
    store_tile(y_window, cast_tile<OutDataType>(y_tile));
}
/// @brief Validates if the given arguments are supported by the pooling kernel.
///
/// @param kargs The pooling kernel arguments containing all necessary parameters.
///
/// @return true if the arguments are supported, false otherwise.
///
/// @note Requirements:
/// - Last dimension (C) must be contiguous (stride = 1) for vectorized access
/// - Window dimensions must be supported (2D or 3D)
/// - All dimension sizes must be consistent between input and output
template <typename TensorShape, typename WindowShape>
CK_TILE_HOST static bool
IsSupportedArgument(const PoolKernelArgs<TensorShape, WindowShape>& kargs)
{
    // Take kargs by const reference: passing by value copied every shape /
    // stride / window tuple on each call for no benefit.
    constexpr index_t InputRank  = TensorShape::size();
    constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
    constexpr index_t WindowRank = WindowShape::size();
    // Validate window dimensions (only 2D and 3D supported)
    if constexpr(WindowRank != 2 && WindowRank != 3)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
        }
        return false;
    }
    // Validate that input rank matches expected rank for window dimensions
    if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
        }
        return false;
    }
    // Check that channel dimension (last dimension) is contiguous for both input and output
    if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
        }
        return false;
    }
    if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
        }
        return false;
    }
    return true;
}
/// @param kargs The pooling kernel arguments
/// @return The calculated grid size
template <typename TensorShape, typename WindowShape>
CK_TILE_HOST static constexpr index_t
CalculateGridSize(PoolKernelArgs<TensorShape, WindowShape> kargs)
{
using S = typename Problem::BlockShape;
// Calculate total output elements (M dimension)
index_t M = 1;
static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
// Calculate grid size: ceil(M / Block_M)
return (M + S::Block_M - 1) / S::Block_M;
}
/// @brief Create kernel arguments from host arguments
template <typename TensorShape, typename WindowShape>
CK_TILE_HOST static constexpr auto
MakeKernelArgs(PoolHostArgs<TensorShape, WindowShape>& host_args)
{
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
host_args.output_ptr,
host_args.input_shape,
host_args.output_shape,
host_args.input_strides,
host_args.output_strides,
host_args.window_lengths,
host_args.window_strides,
host_args.window_dilations,
host_args.input_left_pads,
host_args.input_right_pads};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct PoolDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::InDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// @brief Compile-time problem description for the pooling kernel.
///
/// @tparam InDataType_      Input tensor element type.
/// @tparam OutDataType_     Output tensor element type.
/// @tparam ComputeDataType_ Accumulator type used during the reduction.
/// @tparam IndexDataType_   Element type for an index tensor — presumably for the
///                          planned index-output support; TODO confirm when indexing lands.
/// @tparam ReduceOp_        Reduction operator applied across the pooling window.
/// @tparam OutputIndex_     When true, the kernel is asked to also emit indices.
/// @tparam PropagateNan_    When true, NaN values propagate through the reduction.
/// @tparam BlockShape_      Block/warp/thread tiling shape (see PoolShape).
template <typename InDataType_,
          typename OutDataType_,
          typename ComputeDataType_,
          typename IndexDataType_,
          typename ReduceOp_,
          bool OutputIndex_,
          bool PropagateNan_,
          typename BlockShape_>
struct PoolProblem
{
    using InDataType      = remove_cvref_t<InDataType_>;
    using OutDataType     = remove_cvref_t<OutDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using IndexDataType   = remove_cvref_t<IndexDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;
    using ReduceOp        = ReduceOp_;
    using OutputIndex     = bool_constant<OutputIndex_>;
    using PropagateNan    = bool_constant<PropagateNan_>;

    // More than one lane per warp reduces along N -> intra-warp sync needed.
    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    // More than one warp reduces along N -> cross-warp sync (via shared memory) needed.
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
} // namespace ck_tile

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// @brief Tiling shape for the pooling kernel.
///
/// Describes how a block tile of shape <Block_M, Block_N> is partitioned
/// across warps and threads for the 2D (M x N) reduction view.
template <typename BlockWarps, // num warps along seq<M, N>
          typename BlockTile,  // block size, seq<M, N>
          typename WarpTile,   // warp size, seq<M, N>
          typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct PoolShape
{
    static constexpr index_t Block_M = BlockTile::at(number<0>{});
    static constexpr index_t Block_N = BlockTile::at(number<1>{});

    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile::at(number<1>{});

    static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
    static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});

    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

    static_assert(Warp_M % ThreadTile_M == 0, "Warp_M must be divisible by ThreadTile_M");
    static_assert(Warp_N % ThreadTile_N == 0, "Warp_N must be divisible by ThreadTile_N");
    static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % ck_tile::get_warp_size() == 0,
                  "Warp_M * Warp_N / ThreadTile_M / ThreadTile_N must be a multiple of warp size");

    // Scale factor to account for warp size
    // WarpSizeScaleFactor = warp tile/ thread tile / warp size
    static constexpr index_t WarpSizeScaleFactor =
        Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();

    // The scale factor is applied along whichever dimension holds more threads, so
    // that ThreadPerWarp_M * ThreadPerWarp_N below equals the hardware warp size.
    static constexpr index_t WarpSizeScaleFactor_M =
        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? WarpSizeScaleFactor : 1;
    static constexpr index_t WarpSizeScaleFactor_N =
        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : WarpSizeScaleFactor;

    // Threads per warp along each dimension after applying the scale factor.
    static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / WarpSizeScaleFactor_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / WarpSizeScaleFactor_N;

    static_assert((Block_M * WarpSizeScaleFactor_M) % (WarpPerBlock_M * Warp_M) == 0,
                  "Block_M * WarpSizeScaleFactor_M must be divisible by WarpPerBlock_M * Warp_M");
    static_assert((Block_N * WarpSizeScaleFactor_N) % (WarpPerBlock_N * Warp_N) == 0,
                  "Block_N * WarpSizeScaleFactor_N must be divisible by WarpPerBlock_N * Warp_N");

    // Number of warp-tile iterations each warp performs to cover the block tile.
    static constexpr index_t Repeat_M = Block_M * WarpSizeScaleFactor_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);

    // Total threads per block: warp size times the total number of warps.
    static constexpr index_t BlockSize =
        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -31,3 +31,4 @@ add_subdirectory(epilogue)
add_subdirectory(atomic_add_op)
add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)
add_subdirectory(pooling)

View File

@@ -0,0 +1,3 @@
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_pooling test_pooling.cpp)
endif()

View File

@@ -0,0 +1,249 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pool.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "ck_tile/host/kernel_launch.hpp"
/// @brief Typed test fixture for CK_TILE 3D pooling.
///
/// Each Tuple packs: input type, output type, compute type, reduce op and the
/// four sequences (warps per block, block tile, warp tile, thread tile) that
/// define the PoolShape used by the kernel under test.
template <typename Tuple>
class TestCkTilePooling : public ::testing::Test
{
    protected:
    using InDataType      = std::tuple_element_t<0, Tuple>;
    using OutDataType     = std::tuple_element_t<1, Tuple>;
    using ComputeDataType = std::tuple_element_t<2, Tuple>;
    using ReduceOpType    = std::tuple_element_t<3, Tuple>;
    using BlockWarps_     = std::tuple_element_t<4, Tuple>;
    using BlockTile_      = std::tuple_element_t<5, Tuple>;
    using WarpTile_       = std::tuple_element_t<6, Tuple>;
    using ThreadTile_     = std::tuple_element_t<7, Tuple>;

    using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;

    // 3D pooling configuration (NDHWC layout)
    struct Config3D
    {
        ck_tile::index_t N, D, H, W, C;             // input tensor extents
        ck_tile::index_t Z, Y, X;                   // pooling window extents
        ck_tile::index_t Sz, Sy, Sx;                // window strides
        ck_tile::index_t Dz, Dy, Dx;                // window dilations
        ck_tile::index_t LeftPz, LeftPy, LeftPx;    // left paddings
        ck_tile::index_t RightPz, RightPy, RightPx; // right paddings
        std::string name;                           // human-readable test label
    };

    /// @brief Runs the 3D pooling kernel for @p config and validates the device
    ///        result against the host reference implementation.
    /// @return true when the results match, or when the configuration is
    ///         reported as unsupported (the case is skipped, not failed).
    bool RunPool3D(const Config3D& config)
    {
        std::cout << "Testing 3D: " << config.name << " ... ";

        // Effective (dilated) window extents.
        const ck_tile::index_t Zs = (config.Z - 1) * config.Dz + 1;
        const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
        const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;

        // Output spatial extents.
        const ck_tile::index_t Do =
            (config.D + config.LeftPz + config.RightPz - Zs) / config.Sz + 1;
        const ck_tile::index_t Ho =
            (config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
        const ck_tile::index_t Wo =
            (config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;

        const auto input_shape =
            ck_tile::make_tuple(config.N, config.D, config.H, config.W, config.C);
        const auto output_shape = ck_tile::make_tuple(config.N, Do, Ho, Wo, config.C);

        // Packed NDHWC strides for input and output.
        const auto input_strides = ck_tile::make_tuple(config.D * config.H * config.W * config.C,
                                                       config.H * config.W * config.C,
                                                       config.W * config.C,
                                                       config.C,
                                                       1);
        const auto output_strides = ck_tile::make_tuple(
            Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1);

        const auto window_spatial_lengths = ck_tile::make_tuple(config.Z, config.Y, config.X);
        const auto window_strides         = ck_tile::make_tuple(config.Sz, config.Sy, config.Sx);
        const auto window_dilations       = ck_tile::make_tuple(config.Dz, config.Dy, config.Dx);
        const auto input_left_pads =
            ck_tile::make_tuple(config.LeftPz, config.LeftPy, config.LeftPx);
        const auto input_right_pads =
            ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);

        ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
                                             {config.D * config.H * config.W * config.C,
                                              config.H * config.W * config.C,
                                              config.W * config.C,
                                              config.C,
                                              1});
        ck_tile::HostTensor<OutDataType> h_out(
            {config.N, Do, Ho, Wo, config.C},
            {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
        ck_tile::HostTensor<OutDataType> h_out_ref(
            {config.N, Do, Ho, Wo, config.C},
            {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});

        ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
        h_out.SetZero();
        h_out_ref.SetZero();

        ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
        d_in_mem.ToDevice(h_in.data());
        d_out_mem.ToDevice(h_out.data());

        using Problem = ck_tile::PoolProblem<InDataType,
                                             OutDataType,
                                             ComputeDataType,
                                             OutDataType,
                                             ReduceOpType,
                                             false,
                                             false,
                                             TestPoolShape>;
        using Kernel  = ck_tile::PoolKernel<Problem>;

        constexpr ck_tile::index_t kBlockPerCu = 1;
        const ck_tile::index_t kBlockSize      = Kernel::BlockSize();

        auto host_args =
            ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
                static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
                static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
                input_shape,
                output_shape,
                input_strides,
                output_strides,
                window_spatial_lengths,
                window_strides,
                window_dilations,
                input_left_pads,
                input_right_pads};

        auto kernel_args                 = Kernel::MakeKernelArgs(host_args);
        const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);

        if(!Kernel::IsSupportedArgument(kernel_args))
        {
            // Previously this returned silently, leaving the "Testing ..." line
            // unterminated and reporting an implicit pass. Make the skip visible.
            std::cout << "SKIPPED (unsupported arguments)" << std::endl;
            return true;
        }

        // Run kernel
        ck_tile::launch_kernel(
            ck_tile::stream_config{nullptr, false, 0},
            ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));

        // Run reference implementation
        ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
            h_in, h_out_ref, kernel_args, ReduceOpType{});

        d_out_mem.FromDevice(h_out.data());

        // Validate results
        bool pass = ck_tile::check_err(h_out, h_out_ref);
        std::cout << (pass ? "PASS" : "FAIL") << std::endl;
        return pass;
    }
};
// Single-warp-along-N configuration (WarpPerBlock_N == 1 -> no cross-warp sync).
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
using Shape1_BlockTile  = ck_tile::sequence<128, 128>;
using Shape1_WarpTile   = ck_tile::sequence<32, 128>;
using Shape1_ThreadTile = ck_tile::sequence<8, 8>;

// Cross-warp configuration (WarpPerBlock_N == 2 -> exercises the shared-memory
// cross-warp reduction path).
using Shape2_BlockWarps = ck_tile::sequence<2, 2>;
using Shape2_BlockTile  = ck_tile::sequence<2, 1024>;
using Shape2_WarpTile   = ck_tile::sequence<1, 512>;
using Shape2_ThreadTile = ck_tile::sequence<1, 8>;

// Test configurations for different data types and operations

// fp32 in/out, fp32 compute, max pooling, single-warp-along-N shape.
using TestConfig_F32_Max = std::tuple<float,
                                      float,
                                      float,
                                      ck_tile::ReduceOp::Max,
                                      Shape1_BlockWarps,
                                      Shape1_BlockTile,
                                      Shape1_WarpTile,
                                      Shape1_ThreadTile>;

// fp16 in/out with fp32 compute, max pooling.
using TestConfig_F16_Max = std::tuple<ck_tile::half_t,
                                      ck_tile::half_t,
                                      float,
                                      ck_tile::ReduceOp::Max,
                                      Shape1_BlockWarps,
                                      Shape1_BlockTile,
                                      Shape1_WarpTile,
                                      Shape1_ThreadTile>;

// fp32 max pooling exercising the cross-warp shape.
using TestConfig_F32_CrossWarp = std::tuple<float,
                                            float,
                                            float,
                                            ck_tile::ReduceOp::Max,
                                            Shape2_BlockWarps,
                                            Shape2_BlockTile,
                                            Shape2_WarpTile,
                                            Shape2_ThreadTile>;

using TestTypes =
    ::testing::Types<TestConfig_F32_Max, TestConfig_F16_Max, TestConfig_F32_CrossWarp>;
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
// 2x2x2 max-window pooling, stride 2, no dilation, no padding, on a 1x4x4x4x32 input.
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
{
    typename TestFixture::Config3D cfg{};
    // Input tensor extents (NDHWC)
    cfg.N = 1;
    cfg.D = 4;
    cfg.H = 4;
    cfg.W = 4;
    cfg.C = 32;
    // Pooling window extents
    cfg.Z = 2;
    cfg.Y = 2;
    cfg.X = 2;
    // Window strides
    cfg.Sz = 2;
    cfg.Sy = 2;
    cfg.Sx = 2;
    // Window dilations (1 = no dilation)
    cfg.Dz = 1;
    cfg.Dy = 1;
    cfg.Dx = 1;
    // Left paddings
    cfg.LeftPz = 0;
    cfg.LeftPy = 0;
    cfg.LeftPx = 0;
    // Right paddings
    cfg.RightPz = 0;
    cfg.RightPy = 0;
    cfg.RightPx = 0;
    cfg.name    = "2x2x2 pooling";

    EXPECT_TRUE(this->RunPool3D(cfg));
}
// 3x3x3 max-window pooling, stride 2, unit padding on every side, 2x16x16x16x128 input.
TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
{
    typename TestFixture::Config3D cfg{};
    // Input tensor extents (NDHWC)
    cfg.N = 2;
    cfg.D = 16;
    cfg.H = 16;
    cfg.W = 16;
    cfg.C = 128;
    // Pooling window extents
    cfg.Z = 3;
    cfg.Y = 3;
    cfg.X = 3;
    // Window strides
    cfg.Sz = 2;
    cfg.Sy = 2;
    cfg.Sx = 2;
    // Window dilations (1 = no dilation)
    cfg.Dz = 1;
    cfg.Dy = 1;
    cfg.Dx = 1;
    // Left paddings
    cfg.LeftPz = 1;
    cfg.LeftPy = 1;
    cfg.LeftPx = 1;
    // Right paddings
    cfg.RightPz = 1;
    cfg.RightPy = 1;
    cfg.RightPx = 1;
    cfg.name    = "3x3x3 pooling";

    EXPECT_TRUE(this->RunPool3D(cfg));
}