mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] Pooling FWD (Lwpck 3683) (#2956)
* Pooling 2D/3D with reference * Tests & cleanup - added test for pooling - cleanup - removed 2d example * Comment resolution - README added - example target name rectified - appropriate arg description and comments added * clang-format * appropriate blocksize calc * modifications for future indexing addition - instead of transforming views we now transform the descriptors, so that the same descriptor can be re-used for index tensor in the future * some basic fixes * comment resolutions * comment resolutions --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d4bfe3932
commit
7b6451b68e
@@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
|
||||
* Added support for f32 to FMHA (fwd/bwd).
|
||||
* Added tensor-wise quantization for CK_TILE GEMM.
|
||||
* Added pooling kernel in CK_TILE
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
8
example/ck_tile/36_pooling/CMakeLists.txt
Normal file
8
example/ck_tile/36_pooling/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
set(EXAMPLE_POOL_3D "tile_example_pool3d")
|
||||
message(DEBUG "adding example ${EXAMPLE_POOL_3D}")
|
||||
|
||||
add_executable(${EXAMPLE_POOL_3D} EXCLUDE_FROM_ALL pool3d.cpp)
|
||||
target_include_directories(${EXAMPLE_POOL_3D} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
target_compile_options(${EXAMPLE_POOL_3D} PRIVATE ${EXAMPLE_POOL_COMPILE_OPTIONS})
|
||||
|
||||
42
example/ck_tile/36_pooling/README.md
Normal file
42
example/ck_tile/36_pooling/README.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Pooling Operator
|
||||
|
||||
This folder contains an example of the pooling operator using the ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The 3D pooling example
|
||||
make tile_example_pool3d -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_pool3d`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-N batch size (default:2)
|
||||
-D depth dimension (default:30)
|
||||
-H height dimension (default:30)
|
||||
-W width dimension (default:30)
|
||||
-C channel dimension (default:32)
|
||||
-Z pooling window depth (default:2)
|
||||
-Y pooling window height (default:2)
|
||||
-X pooling window width (default:2)
|
||||
-Sz window stride depth (default:2)
|
||||
-Sy window stride height (default:2)
|
||||
-Sx window stride width (default:2)
|
||||
-Dz window dilation depth (default:1)
|
||||
-Dy window dilation height (default:1)
|
||||
-Dx window dilation width (default:1)
|
||||
-LeftPz left padding depth (default:1)
|
||||
-LeftPy left padding height (default:1)
|
||||
-LeftPx left padding width (default:1)
|
||||
-RightPz right padding depth (default:1)
|
||||
-RightPy right padding height (default:1)
|
||||
-RightPx right padding width (default:1)
|
||||
-v 0: No validation, 1: CPU validation (default:1)
|
||||
-warmup number of iterations before benchmark (default:0)
|
||||
-repeat number of iterations to benchmark (default:1)
|
||||
```
|
||||
188
example/ck_tile/36_pooling/pool3d.cpp
Normal file
188
example/ck_tile/36_pooling/pool3d.cpp
Normal file
@@ -0,0 +1,188 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pool.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// Parse command-line arguments for 3D pooling example
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("N", "2", "N dimension")
|
||||
.insert("H", "30", "H dimension")
|
||||
.insert("W", "30", "W dimension")
|
||||
.insert("C", "32", "C dimension")
|
||||
.insert("D", "30", "D dimension")
|
||||
.insert("Z", "2", "Z dimension")
|
||||
.insert("Y", "2", "Y dimension")
|
||||
.insert("X", "2", "X dimension")
|
||||
.insert("Sz", "2", "window stride d")
|
||||
.insert("Sy", "2", "window stride h")
|
||||
.insert("Sx", "2", "window stride w")
|
||||
.insert("Dz", "1", "window dilation d")
|
||||
.insert("Dy", "1", "window dilation h")
|
||||
.insert("Dx", "1", "window dilation w")
|
||||
.insert("LeftPz", "1", "left padding d")
|
||||
.insert("LeftPy", "1", "left padding h")
|
||||
.insert("LeftPx", "1", "left padding w")
|
||||
.insert("RightPz", "1", "right padding d")
|
||||
.insert("RightPy", "1", "right padding h")
|
||||
.insert("RightPx", "1", "right padding w")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
/// @brief Run the 3D max-pooling example: build arguments, launch the kernel,
///        report bandwidth and optionally validate against a CPU reference.
/// @tparam InDataType      element type of the input tensor
/// @tparam OutDataType     element type of the output tensor
/// @tparam ComputeDataType accumulator type used inside the reduction
/// @param arg_parser parsed command-line options (see create_args)
/// @return true when validation passes or is skipped, false on mismatch
template <typename InDataType, typename OutDataType, typename ComputeDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    // Input tensor extents (NDHWC layout).
    const ck_tile::index_t N = arg_parser.get_int("N");
    const ck_tile::index_t H = arg_parser.get_int("H");
    const ck_tile::index_t W = arg_parser.get_int("W");
    const ck_tile::index_t C = arg_parser.get_int("C");
    const ck_tile::index_t D = arg_parser.get_int("D");

    // Pooling window lengths (depth, height, width).
    const ck_tile::index_t Z = arg_parser.get_int("Z");
    const ck_tile::index_t Y = arg_parser.get_int("Y");
    const ck_tile::index_t X = arg_parser.get_int("X");

    // Window strides per spatial dimension.
    const ck_tile::index_t Sz = arg_parser.get_int("Sz");
    const ck_tile::index_t Sy = arg_parser.get_int("Sy");
    const ck_tile::index_t Sx = arg_parser.get_int("Sx");

    // Window dilations per spatial dimension.
    const ck_tile::index_t Dz = arg_parser.get_int("Dz");
    const ck_tile::index_t Dy = arg_parser.get_int("Dy");
    const ck_tile::index_t Dx = arg_parser.get_int("Dx");

    // Left/right spatial padding.
    const ck_tile::index_t LeftPz = arg_parser.get_int("LeftPz");
    const ck_tile::index_t LeftPy = arg_parser.get_int("LeftPy");
    const ck_tile::index_t LeftPx = arg_parser.get_int("LeftPx");
    const ck_tile::index_t RightPz = arg_parser.get_int("RightPz");
    const ck_tile::index_t RightPy = arg_parser.get_int("RightPy");
    const ck_tile::index_t RightPx = arg_parser.get_int("RightPx");

    // Effective (dilated) window extent in each spatial dimension.
    const ck_tile::index_t Zs = (Z - 1) * Dz + 1;
    const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
    const ck_tile::index_t Xs = (X - 1) * Dx + 1;

    // Output spatial extents from the standard sliding-window formula.
    const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
    const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
    const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;

    printf("Input parameters:\n");
    // NOTE(review): %d assumes ck_tile::index_t is int — confirm.
    printf("N: %d, D: %d, H: %d, W: %d, C: %d\n", N, D, H, W, C);
    printf("Window Z: %d, Y: %d, X: %d, Stride Z: %d, Y: %d, X: %d\n", Z, Y, X, Sz, Sy, Sx);
    printf("Output Do: %d, Ho: %d, Wo: %d\n", Do, Ho, Wo);

    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");

    // Shapes / strides / parameters (NDHWC, innermost C is contiguous).
    const auto input_shape            = ck_tile::make_tuple(N, D, H, W, C);
    const auto output_shape           = ck_tile::make_tuple(N, Do, Ho, Wo, C);
    const auto input_strides          = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1);
    const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1);
    const auto window_spatial_lengths = ck_tile::make_tuple(Z, Y, X);
    const auto window_strides         = ck_tile::make_tuple(Sz, Sy, Sx);
    const auto window_dilations       = ck_tile::make_tuple(Dz, Dy, Dx);
    const auto input_left_pads        = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
    const auto input_right_pads       = ck_tile::make_tuple(RightPz, RightPy, RightPx);

    // Host-side tensors: input, device result, and CPU reference result.
    ck_tile::HostTensor<InDataType> in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1});
    ck_tile::HostTensor<OutDataType> out({N, Do, Ho, Wo, C},
                                         {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
    ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
                                             {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});

    // Random input in [-5, 5].
    ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);

    ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
    ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());

    in_buf.ToDevice(in.data());

    // Max-pooling with a fixed tile configuration.
    using ReduceOp   = ck_tile::ReduceOp::Max;
    using BlockWarps = ck_tile::sequence<4, 1>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using WarpTile   = ck_tile::sequence<32, 128>;
    using ThreadTile = ck_tile::sequence<8, 8>;

    using Shape   = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    using Problem = ck_tile::PoolProblem<InDataType,
                                         OutDataType,
                                         ComputeDataType,
                                         OutDataType,
                                         ReduceOp,
                                         false,
                                         false,
                                         Shape>;
    using Kernel  = ck_tile::PoolKernel<Problem>;

    constexpr ck_tile::index_t kBlockPerCu = 1;
    const ck_tile::index_t kBlockSize      = Kernel::BlockSize();

    auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
        static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
        static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_spatial_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};

    auto kernel_args = Kernel::MakeKernelArgs(host_args);

    const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
    std::cout << "grid size " << kGridSize << std::endl;

    // Validate kernel can handle the given configuration
    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        throw std::runtime_error("ERROR: Kernel arguments are not supported! \n");
    }

    // Launch and time the kernel (warmup cold iterations, repeat hot iterations).
    float ave_time = launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));

    // Total bytes moved: one input read plus one output write.
    std::size_t num_btype =
        sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C;

    // bytes / (1e6 * ms) == GB/s.
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;

    bool pass = true;

    if(do_validation)
    {
        // CPU reference uses the same kernel arguments, so shapes/pads agree.
        ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
            in, out_ref, kernel_args, ReduceOp{});
        out_buf.FromDevice(out.mData.data());
        pass = ck_tile::check_err(out, out_ref);

        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
@@ -23,6 +23,7 @@ add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(21_elementwise)
|
||||
add_subdirectory(22_gemm_multi_abd)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(36_pooling)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
|
||||
147
include/ck_tile/host/reference/reference_pool.hpp
Normal file
147
include/ck_tile/host/reference/reference_pool.hpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
|
||||
|
||||
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});
|
||||
|
||||
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
|
||||
// Right padding is handled implicitly by bounds checking
|
||||
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
||||
|
||||
for(ck_tile::index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Calculate input width index with stride, dilation, and padding
|
||||
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
||||
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
|
||||
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
|
||||
|
||||
const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
|
||||
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
|
||||
|
||||
const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});
|
||||
|
||||
const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
|
||||
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
|
||||
// Right padding is handled implicitly by bounds checking
|
||||
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
|
||||
|
||||
for(ck_tile::index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Calculate input width index with stride, dilation, and padding
|
||||
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
|
||||
|
||||
if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
11
include/ck_tile/ops/pool.hpp
Normal file
11
include/ck_tile/ops/pool.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
496
include/ck_tile/ops/pooling/kernel/pool_kernel.hpp
Normal file
496
include/ck_tile/ops/pooling/kernel/pool_kernel.hpp
Normal file
@@ -0,0 +1,496 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Host arguments for pooling operations.
///
/// Bundles the device pointers plus every shape/stride/window parameter the
/// kernel needs. TensorShape holds the full tensor extents (e.g. NDHWC);
/// WindowShape holds the spatial-only window parameters (e.g. Z, Y, X).
template <typename TensorShape, typename WindowShape>
struct PoolHostArgs
{

    /// @param input_ptr_       device pointer to the input tensor (read-only)
    /// @param output_ptr_      device pointer to the output tensor
    /// @param input_shape_     input tensor extents
    /// @param output_shape_    output tensor extents
    /// @param input_strides_   input tensor strides (elements)
    /// @param output_strides_  output tensor strides (elements)
    /// @param window_lengths_  pooling window extents per spatial dim
    /// @param window_strides_  window strides per spatial dim
    /// @param window_dilations_ window dilations per spatial dim
    /// @param input_left_pads_  left padding per spatial dim
    /// @param input_right_pads_ right padding per spatial dim
    CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
                              void* output_ptr_,
                              TensorShape input_shape_,
                              TensorShape output_shape_,
                              TensorShape input_strides_,
                              TensorShape output_strides_,
                              WindowShape window_lengths_,
                              WindowShape window_strides_,
                              WindowShape window_dilations_,
                              WindowShape input_left_pads_,
                              WindowShape input_right_pads_)
        : input_ptr(input_ptr_),
          output_ptr(output_ptr_),
          input_shape(input_shape_),
          output_shape(output_shape_),
          input_strides(input_strides_),
          output_strides(output_strides_),
          window_lengths(window_lengths_),
          window_strides(window_strides_),
          window_dilations(window_dilations_),
          input_left_pads(input_left_pads_),
          input_right_pads(input_right_pads_)
    {
    }

    const void* input_ptr; // device pointer, type-erased; cast by the kernel
    void* output_ptr;      // device pointer, type-erased; cast by the kernel

    TensorShape input_shape;
    TensorShape output_shape;
    TensorShape input_strides;
    TensorShape output_strides;
    WindowShape window_lengths;
    WindowShape window_strides;
    WindowShape window_dilations;
    WindowShape input_left_pads;
    WindowShape input_right_pads;
};
|
||||
|
||||
/// @brief Kernel arguments for pooling operations.
///
/// Plain aggregate mirroring PoolHostArgs, passed by value to the device
/// kernel. Field meanings match the host-side struct one-to-one.
template <typename TensorShape, typename WindowShape>
struct PoolKernelArgs
{
    const void* input_ptr; // device pointer to input data (type-erased)
    void* output_ptr;      // device pointer to output data (type-erased)
    TensorShape input_shape;
    TensorShape output_shape;
    TensorShape input_strides;
    TensorShape output_strides;
    WindowShape window_lengths;
    WindowShape window_strides;
    WindowShape window_dilations;
    WindowShape input_left_pads;
    WindowShape input_right_pads;
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
|
||||
struct PoolKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
/// @brief Threads per workgroup, halved on wave32 hardware so the
///        warp-count configuration stays valid.
CK_TILE_HOST static constexpr auto BlockSize()
{
    if(is_wave32())
    {
        return kBlockSize / 2;
    }
    return kBlockSize;
}
|
||||
|
||||
/// @brief Build the padded 2D-pooling input/output tensor views.
///
/// Transforms the NHWC input descriptor into a 2D (M x K) view where
/// M = N*Ho*Wo*C enumerates output positions and K = Y*X enumerates the
/// taps of one pooling window; the output becomes a 1D (M) view. Both are
/// right-padded up to the block tile so the kernel needs no edge cases.
/// @return tuple of (padded input tensor view, padded output tensor view)
template <typename TensorShape, typename WindowShape>
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs<TensorShape, WindowShape> kargs)
{
    using S = typename Problem::BlockShape;

    // Compile-time validation for 2D pooling
    static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
    static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");

    // Extract dimension values
    const index_t N = kargs.input_shape.at(number<0>{});
    const index_t H = kargs.input_shape.at(number<1>{});
    const index_t W = kargs.input_shape.at(number<2>{});
    const index_t C = kargs.input_shape.at(number<3>{});

    const index_t No = kargs.output_shape.at(number<0>{});
    const index_t Ho = kargs.output_shape.at(number<1>{});
    const index_t Wo = kargs.output_shape.at(number<2>{});
    const index_t Co = kargs.output_shape.at(number<3>{});

    const index_t Y = kargs.window_lengths.at(number<0>{});
    const index_t X = kargs.window_lengths.at(number<1>{});

    const index_t WindowStrideH = kargs.window_strides.at(number<0>{});
    const index_t WindowStrideW = kargs.window_strides.at(number<1>{});

    const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
    const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});

    const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{});
    const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{});

    const index_t InRightPadH = kargs.input_right_pads.at(number<0>{});
    const index_t InRightPadW = kargs.input_right_pads.at(number<1>{});

    // Flattened problem sizes and the padding needed to reach whole tiles.
    const index_t MRaw = N * Ho * Wo * C;
    const index_t KRaw = Y * X;
    const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
    const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;

    auto reduce_op = typename Problem::ReduceOp{};

    // Create input descriptor with all transformations
    auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);

    // Apply spatial padding to input descriptor
    const auto padded_in_desc = transform_tensor_descriptor(
        in_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(H, InLeftPadH, InRightPadH),
                   make_pad_transform(W, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));

    // Create sliding windows by embedding pooling windows into descriptor:
    // each spatial dim splits into (window tap, output position).
    const auto embed_in_desc = transform_tensor_descriptor(
        padded_in_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
            make_pass_through_transform(C)),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
        make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));

    // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
    const auto merged_embed_in_desc =
        transform_tensor_descriptor(embed_in_desc,
                                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
                                               make_merge_transform(make_tuple(Y, X))),
                                    make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
                                    make_tuple(sequence<0>{}, sequence<1>{}));

    // Right-pad M and K so the view is an exact multiple of the block tile.
    const auto in_desc_padded = transform_tensor_descriptor(
        merged_embed_in_desc,
        make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    // Create output descriptor with transformations
    auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);

    // Flatten the output to one M dimension matching the input view's M.
    const auto merged_out_desc = transform_tensor_descriptor(
        out_desc,
        make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))),
        make_tuple(sequence<0, 1, 2, 3>{}),
        make_tuple(sequence<0>{}));

    const auto out_desc_padded =
        transform_tensor_descriptor(merged_out_desc,
                                    make_tuple(make_right_pad_transform(MRaw, MPad)),
                                    make_tuple(sequence<0>{}),
                                    make_tuple(sequence<0>{}));

    // Now create buffer views and tensor views with the fully transformed
    // descriptors. Out-of-bounds reads return the reduction identity so
    // padded taps do not affect the result.
    const InDataType in_identity =
        type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
    const OutDataType out_identity =
        type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());

    auto in_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<const InDataType*>(kargs.input_ptr),
        in_desc.get_element_space_size(),
        in_identity);
    const auto in_tensor_padded =
        tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
                                                                        in_desc_padded};

    auto out_buffer_view = make_buffer_view<address_space_enum::global>(
        static_cast<OutDataType*>(kargs.output_ptr),
        out_desc.get_element_space_size(),
        out_identity);
    const auto out_tensor_padded =
        tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
                                                                          out_desc_padded};

    return make_tuple(in_tensor_padded, out_tensor_padded);
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Compile-time validation for 3D pooling
|
||||
static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
|
||||
static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");
|
||||
|
||||
// Extract dimension values
|
||||
const index_t N = kargs.input_shape.at(number<0>{});
|
||||
const index_t D = kargs.input_shape.at(number<1>{});
|
||||
const index_t H = kargs.input_shape.at(number<2>{});
|
||||
const index_t W = kargs.input_shape.at(number<3>{});
|
||||
const index_t C = kargs.input_shape.at(number<4>{});
|
||||
|
||||
const index_t No = kargs.output_shape.at(number<0>{});
|
||||
const index_t Do = kargs.output_shape.at(number<1>{});
|
||||
const index_t Ho = kargs.output_shape.at(number<2>{});
|
||||
const index_t Wo = kargs.output_shape.at(number<3>{});
|
||||
const index_t Co = kargs.output_shape.at(number<4>{});
|
||||
|
||||
const index_t Z = kargs.window_lengths.at(number<0>{});
|
||||
const index_t Y = kargs.window_lengths.at(number<1>{});
|
||||
const index_t X = kargs.window_lengths.at(number<2>{});
|
||||
|
||||
const index_t WindowStrideD = kargs.window_strides.at(number<0>{});
|
||||
const index_t WindowStrideH = kargs.window_strides.at(number<1>{});
|
||||
const index_t WindowStrideW = kargs.window_strides.at(number<2>{});
|
||||
|
||||
const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
|
||||
const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
|
||||
const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});
|
||||
|
||||
const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{});
|
||||
const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{});
|
||||
const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{});
|
||||
|
||||
const index_t InRightPadD = kargs.input_right_pads.at(number<0>{});
|
||||
const index_t InRightPadH = kargs.input_right_pads.at(number<1>{});
|
||||
const index_t InRightPadW = kargs.input_right_pads.at(number<2>{});
|
||||
|
||||
const index_t MRaw = N * Do * Ho * Wo * C;
|
||||
const index_t KRaw = Z * Y * X;
|
||||
const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
|
||||
const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
|
||||
|
||||
auto reduce_op = typename Problem::ReduceOp{};
|
||||
|
||||
// Create input descriptor with all transformations
|
||||
auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
|
||||
|
||||
// Apply spatial padding to input descriptor (all 3D dimensions)
|
||||
const auto padded_in_desc = transform_tensor_descriptor(
|
||||
in_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(D, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(H, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(W, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
// Create 3D sliding windows by embedding pooling windows into descriptor
|
||||
const auto embed_in_desc = transform_tensor_descriptor(
|
||||
padded_in_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
// Reshape into 2D matrix: output positions (M) x pooling window elements (K)
|
||||
const auto merged_embed_in_desc = transform_tensor_descriptor(
|
||||
embed_in_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
|
||||
make_merge_transform(make_tuple(Z, Y, X))),
|
||||
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_desc_padded = transform_tensor_descriptor(
|
||||
merged_embed_in_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Create output descriptor with transformations
|
||||
auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
|
||||
|
||||
const auto merged_out_desc = transform_tensor_descriptor(
|
||||
out_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
|
||||
make_tuple(sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
const auto out_desc_padded =
|
||||
transform_tensor_descriptor(merged_out_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
// Now create buffer views and tensor views with the fully transformed descriptors
|
||||
const InDataType in_identity =
|
||||
type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
const OutDataType out_identity =
|
||||
type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
auto in_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<const InDataType*>(kargs.input_ptr),
|
||||
in_desc.get_element_space_size(),
|
||||
in_identity);
|
||||
const auto in_tensor_padded =
|
||||
tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
|
||||
in_desc_padded};
|
||||
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(kargs.output_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
out_identity);
|
||||
const auto out_tensor_padded =
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_DEVICE void operator()(PoolKernelArgs<TensorShape, WindowShape> kargs) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Compile-time validation for supported window dimensions
|
||||
static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
|
||||
"Only 2D and 3D pooling operations are supported");
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
return MakeTensorView3D(kargs);
|
||||
else
|
||||
static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
|
||||
"Unsupported WindowShape rank: only 2D or 3D pooling is supported");
|
||||
}();
|
||||
|
||||
auto reduce_op = typename Problem::ReduceOp{};
|
||||
|
||||
auto x_window = make_tile_window(in_tensor_padded,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
__shared__ char smem[Policy::template GetSmemSize<Problem>()];
|
||||
|
||||
const auto reduce_len =
|
||||
in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
|
||||
index_t num_k_tiles =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
|
||||
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorTile = decltype(load_tile(x_window));
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
///
|
||||
/// @param kargs The pooling kernel arguments containing all necessary parameters.
|
||||
///
|
||||
/// @return true if the arguments are supported, false otherwise.
|
||||
///
|
||||
/// @note Requirements:
|
||||
/// - Last dimension (C) must be contiguous (stride = 1) for vectorized access
|
||||
/// - Window dimensions must be supported (2D or 3D)
|
||||
/// - All dimension sizes must be consistent between input and output
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
constexpr index_t InputRank = TensorShape::size();
|
||||
constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
|
||||
constexpr index_t WindowRank = WindowShape::size();
|
||||
|
||||
// Validate window dimensions (only 2D and 3D supported)
|
||||
if constexpr(WindowRank != 2 && WindowRank != 3)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate that input rank matches expected rank for window dimensions
|
||||
if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that channel dimension (last dimension) is contiguous for both input and output
|
||||
if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @param kargs The pooling kernel arguments
|
||||
/// @return The calculated grid size
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static constexpr index_t
|
||||
CalculateGridSize(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Calculate total output elements (M dimension)
|
||||
index_t M = 1;
|
||||
static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
|
||||
|
||||
// Calculate grid size: ceil(M / Block_M)
|
||||
return (M + S::Block_M - 1) / S::Block_M;
|
||||
}
|
||||
|
||||
/// @brief Create kernel arguments from host arguments
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(PoolHostArgs<TensorShape, WindowShape>& host_args)
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
host_args.output_strides,
|
||||
host_args.window_lengths,
|
||||
host_args.window_strides,
|
||||
host_args.window_dilations,
|
||||
host_args.input_left_pads,
|
||||
host_args.input_right_pads};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
80
include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp
Normal file
80
include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp
Normal file
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct PoolDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::InDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
33
include/ck_tile/ops/pooling/pipeline/pool_problem.hpp
Normal file
33
include/ck_tile/ops/pooling/pipeline/pool_problem.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Compile-time problem description for the pooling kernel.
///
/// Bundles the element types, the reduction operator and the block tiling shape,
/// and derives whether the reduction needs cross-lane and/or cross-warp sync.
template <typename InDataType_,
          typename OutDataType_,
          typename ComputeDataType_,
          typename IndexDataType_,
          typename ReduceOp_,
          bool OutputIndex_,
          bool PropagateNan_,
          typename BlockShape_>
struct PoolProblem
{
    // Element types (cv/ref-stripped so callers may pass qualified types).
    using InDataType      = remove_cvref_t<InDataType_>;
    using OutDataType     = remove_cvref_t<OutDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using IndexDataType   = remove_cvref_t<IndexDataType_>;
    // Tiling configuration and reduction operator.
    using BlockShape = remove_cvref_t<BlockShape_>;
    using ReduceOp   = ReduceOp_;
    // Feature flags exposed as integral-constant types.
    // NOTE(review): OutputIndex presumably enables index-tensor output in the
    // future (not exercised in this chunk) — confirm against kernel usage.
    using OutputIndex  = bool_constant<OutputIndex_>;
    using PropagateNan = bool_constant<PropagateNan_>;

    // A reduction row spanning multiple lanes/warps requires the matching sync step.
    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
57
include/ck_tile/ops/pooling/pipeline/pool_shape.hpp
Normal file
57
include/ck_tile/ops/pooling/pipeline/pool_shape.hpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Block/warp/thread tiling shape for the pooling kernel's 2D reduction.
///        Dimension 0 (M) indexes output positions, dimension 1 (N) the reduction axis.
template <typename BlockWarps, // num warps along seq<M, N>
          typename BlockTile,  // block size, seq<M, N>
          typename WarpTile,   // warp size, seq<M, N>
          typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct PoolShape
{
    // Elements covered by one workgroup along M / N.
    static constexpr index_t Block_M = BlockTile::at(number<0>{});
    static constexpr index_t Block_N = BlockTile::at(number<1>{});

    // Elements covered by one warp along M / N.
    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile::at(number<1>{});

    // Contiguous elements (vector size) handled by one thread along M / N.
    static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
    static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});

    // Warps per workgroup along M / N.
    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

    static_assert(Warp_M % ThreadTile_M == 0, "Warp_M must be divisible by ThreadTile_M");
    static_assert(Warp_N % ThreadTile_N == 0, "Warp_N must be divisible by ThreadTile_N");
    static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % ck_tile::get_warp_size() == 0,
                  "Warp_M * Warp_N / ThreadTile_M / ThreadTile_N must be a multiple of warp size");

    // Scale factor to account for warp size
    // WarpSizeScaleFactor = warp tile/ thread tile / warp size
    static constexpr index_t WarpSizeScaleFactor =
        Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();

    // Fold the scale factor into whichever dimension has more threads, so that
    // ThreadPerWarp_M * ThreadPerWarp_N ends up equal to the hardware warp size.
    static constexpr index_t WarpSizeScaleFactor_M =
        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? WarpSizeScaleFactor : 1;
    static constexpr index_t WarpSizeScaleFactor_N =
        (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : WarpSizeScaleFactor;

    // Threads along each dimension within one warp (after warp-size normalization).
    static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / WarpSizeScaleFactor_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / WarpSizeScaleFactor_N;

    static_assert((Block_M * WarpSizeScaleFactor_M) % (WarpPerBlock_M * Warp_M) == 0,
                  "Block_M * WarpSizeScaleFactor_M must be divisible by WarpPerBlock_M * Warp_M");
    static_assert((Block_N * WarpSizeScaleFactor_N) % (WarpPerBlock_N * Warp_N) == 0,
                  "Block_N * WarpSizeScaleFactor_N must be divisible by WarpPerBlock_N * Warp_N");

    // Times each warp repeats its tile to cover the whole block tile along M / N.
    static constexpr index_t Repeat_M = Block_M * WarpSizeScaleFactor_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);

    // Threads per workgroup: warp size times the total number of warps.
    static constexpr index_t BlockSize =
        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
|
||||
} // namespace ck_tile
|
||||
@@ -31,3 +31,4 @@ add_subdirectory(epilogue)
|
||||
add_subdirectory(atomic_add_op)
|
||||
add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
|
||||
3
test/ck_tile/pooling/CMakeLists.txt
Normal file
3
test/ck_tile/pooling/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_pooling test_pooling.cpp)
|
||||
endif()
|
||||
249
test/ck_tile/pooling/test_pooling.cpp
Normal file
249
test/ck_tile/pooling/test_pooling.cpp
Normal file
@@ -0,0 +1,249 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <tuple>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pool.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTilePooling : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using InDataType = std::tuple_element_t<0, Tuple>;
|
||||
using OutDataType = std::tuple_element_t<1, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<2, Tuple>;
|
||||
using ReduceOpType = std::tuple_element_t<3, Tuple>;
|
||||
using BlockWarps_ = std::tuple_element_t<4, Tuple>;
|
||||
using BlockTile_ = std::tuple_element_t<5, Tuple>;
|
||||
using WarpTile_ = std::tuple_element_t<6, Tuple>;
|
||||
using ThreadTile_ = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
|
||||
|
||||
// 3D pooling configuration
|
||||
struct Config3D
|
||||
{
|
||||
ck_tile::index_t N, D, H, W, C;
|
||||
ck_tile::index_t Z, Y, X;
|
||||
ck_tile::index_t Sz, Sy, Sx;
|
||||
ck_tile::index_t Dz, Dy, Dx;
|
||||
ck_tile::index_t LeftPz, LeftPy, LeftPx;
|
||||
ck_tile::index_t RightPz, RightPy, RightPx;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
bool RunPool3D(const Config3D& config)
|
||||
{
|
||||
std::cout << "Testing 3D: " << config.name << " ... ";
|
||||
|
||||
const ck_tile::index_t Zs = (config.Z - 1) * config.Dz + 1;
|
||||
const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
|
||||
const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
|
||||
const ck_tile::index_t Do =
|
||||
(config.D + config.LeftPz + config.RightPz - Zs) / config.Sz + 1;
|
||||
const ck_tile::index_t Ho =
|
||||
(config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
|
||||
const ck_tile::index_t Wo =
|
||||
(config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;
|
||||
|
||||
const auto input_shape =
|
||||
ck_tile::make_tuple(config.N, config.D, config.H, config.W, config.C);
|
||||
const auto output_shape = ck_tile::make_tuple(config.N, Do, Ho, Wo, config.C);
|
||||
const auto input_strides = ck_tile::make_tuple(config.D * config.H * config.W * config.C,
|
||||
config.H * config.W * config.C,
|
||||
config.W * config.C,
|
||||
config.C,
|
||||
1);
|
||||
const auto output_strides = ck_tile::make_tuple(
|
||||
Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1);
|
||||
const auto window_spatial_lengths = ck_tile::make_tuple(config.Z, config.Y, config.X);
|
||||
const auto window_strides = ck_tile::make_tuple(config.Sz, config.Sy, config.Sx);
|
||||
const auto window_dilations = ck_tile::make_tuple(config.Dz, config.Dy, config.Dx);
|
||||
const auto input_left_pads =
|
||||
ck_tile::make_tuple(config.LeftPz, config.LeftPy, config.LeftPx);
|
||||
const auto input_right_pads =
|
||||
ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
|
||||
{config.D * config.H * config.W * config.C,
|
||||
config.H * config.W * config.C,
|
||||
config.W * config.C,
|
||||
config.C,
|
||||
1});
|
||||
ck_tile::HostTensor<OutDataType> h_out(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
h_out.SetZero();
|
||||
h_out_ref.SetZero();
|
||||
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
ReduceOpType,
|
||||
false,
|
||||
false,
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
auto host_args =
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(host_args);
|
||||
|
||||
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference implementation
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
h_in, h_out_ref, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(h_out, h_out_ref);
|
||||
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
};
|
||||
|
||||
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using Shape1_BlockTile = ck_tile::sequence<128, 128>;
|
||||
using Shape1_WarpTile = ck_tile::sequence<32, 128>;
|
||||
using Shape1_ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
// Cross-warp configuration
|
||||
using Shape2_BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using Shape2_BlockTile = ck_tile::sequence<2, 1024>;
|
||||
using Shape2_WarpTile = ck_tile::sequence<1, 512>;
|
||||
using Shape2_ThreadTile = ck_tile::sequence<1, 8>;
|
||||
|
||||
// Test configurations for different data types and operations
|
||||
using TestConfig_F32_Max = std::tuple<float,
|
||||
float,
|
||||
float,
|
||||
ck_tile::ReduceOp::Max,
|
||||
Shape1_BlockWarps,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile,
|
||||
Shape1_ThreadTile>;
|
||||
|
||||
using TestConfig_F16_Max = std::tuple<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::ReduceOp::Max,
|
||||
Shape1_BlockWarps,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile,
|
||||
Shape1_ThreadTile>;
|
||||
|
||||
using TestConfig_F32_CrossWarp = std::tuple<float,
|
||||
float,
|
||||
float,
|
||||
ck_tile::ReduceOp::Max,
|
||||
Shape2_BlockWarps,
|
||||
Shape2_BlockTile,
|
||||
Shape2_WarpTile,
|
||||
Shape2_ThreadTile>;
|
||||
|
||||
using TestTypes =
|
||||
::testing::Types<TestConfig_F32_Max, TestConfig_F16_Max, TestConfig_F32_CrossWarp>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
|
||||
|
||||
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
{
    // Non-overlapping 2x2x2 window: stride equals window size, unit dilation,
    // no padding.
    typename TestFixture::Config3D cfg{};
    cfg.N = 1;  // batch size
    cfg.D = 4;  // input depth
    cfg.H = 4;  // input height
    cfg.W = 4;  // input width
    cfg.C = 32; // channels

    cfg.Z = 2; // window depth
    cfg.Y = 2; // window height
    cfg.X = 2; // window width

    cfg.Sz = 2; // stride depth
    cfg.Sy = 2; // stride height
    cfg.Sx = 2; // stride width

    cfg.Dz = 1; // dilation depth
    cfg.Dy = 1; // dilation height
    cfg.Dx = 1; // dilation width

    cfg.LeftPz = 0; // left padding (depth/height/width)
    cfg.LeftPy = 0;
    cfg.LeftPx = 0;

    cfg.RightPz = 0; // right padding (depth/height/width)
    cfg.RightPy = 0;
    cfg.RightPx = 0;

    cfg.name = "2x2x2 pooling";

    EXPECT_TRUE(this->RunPool3D(cfg));
}
|
||||
|
||||
TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
{
    // Overlapping 3x3x3 window: stride 2 with unit padding on every side.
    typename TestFixture::Config3D cfg{};
    cfg.N = 2;   // batch size
    cfg.D = 16;  // input depth
    cfg.H = 16;  // input height
    cfg.W = 16;  // input width
    cfg.C = 128; // channels

    cfg.Z = 3; // window depth
    cfg.Y = 3; // window height
    cfg.X = 3; // window width

    cfg.Sz = 2; // stride depth
    cfg.Sy = 2; // stride height
    cfg.Sx = 2; // stride width

    cfg.Dz = 1; // dilation depth
    cfg.Dy = 1; // dilation height
    cfg.Dx = 1; // dilation width

    cfg.LeftPz = 1; // left padding (depth/height/width)
    cfg.LeftPy = 1;
    cfg.LeftPx = 1;

    cfg.RightPz = 1; // right padding (depth/height/width)
    cfg.RightPy = 1;
    cfg.RightPx = 1;

    cfg.name = "3x3x3 pooling";

    EXPECT_TRUE(this->RunPool3D(cfg));
}
|
||||
Reference in New Issue
Block a user