Files
composable_kernel/example/ck_tile/36_pooling/pool3d.cpp
Yashvardhan Agarwal 7b6451b68e [CK_TILE] Pooling FWD (Lwpck 3683) (#2956)
* Pooling 2D/3D with refernce

* Tests & cleanup

- added test for ppoling
- cleanup
- removed 2d example

* Comment resolution

- README added
- example target name rectified
- appropriate arg description and comments added

* clang-format

* appropriate blocksize calc

* modifications for future indexing addition

- instead of transforming views we now transform the descriptors, so
that the same descriptor can be re-used for index tensor in the future

* some basic fixes

* comment resolutions

* comment resolutions

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
2025-10-09 16:13:26 +02:00

189 lines
7.7 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pool.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include <cstring>
// Parse command-line arguments for 3D pooling example
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("N", "2", "N dimension")
.insert("H", "30", "H dimension")
.insert("W", "30", "W dimension")
.insert("C", "32", "C dimension")
.insert("D", "30", "D dimension")
.insert("Z", "2", "Z dimension")
.insert("Y", "2", "Y dimension")
.insert("X", "2", "X dimension")
.insert("Sz", "2", "window stride d")
.insert("Sy", "2", "window stride h")
.insert("Sx", "2", "window stride w")
.insert("Dz", "1", "window dilation d")
.insert("Dy", "1", "window dilation h")
.insert("Dx", "1", "window dilation w")
.insert("LeftPz", "1", "left padding d")
.insert("LeftPy", "1", "left padding h")
.insert("LeftPx", "1", "left padding w")
.insert("RightPz", "1", "right padding d")
.insert("RightPy", "1", "right padding h")
.insert("RightPx", "1", "right padding w")
.insert("v", "1", "cpu validation or not")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InDataType, typename OutDataType, typename ComputeDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t H = arg_parser.get_int("H");
const ck_tile::index_t W = arg_parser.get_int("W");
const ck_tile::index_t C = arg_parser.get_int("C");
const ck_tile::index_t D = arg_parser.get_int("D");
const ck_tile::index_t Z = arg_parser.get_int("Z");
const ck_tile::index_t Y = arg_parser.get_int("Y");
const ck_tile::index_t X = arg_parser.get_int("X");
const ck_tile::index_t Sz = arg_parser.get_int("Sz");
const ck_tile::index_t Sy = arg_parser.get_int("Sy");
const ck_tile::index_t Sx = arg_parser.get_int("Sx");
const ck_tile::index_t Dz = arg_parser.get_int("Dz");
const ck_tile::index_t Dy = arg_parser.get_int("Dy");
const ck_tile::index_t Dx = arg_parser.get_int("Dx");
const ck_tile::index_t LeftPz = arg_parser.get_int("LeftPz");
const ck_tile::index_t LeftPy = arg_parser.get_int("LeftPy");
const ck_tile::index_t LeftPx = arg_parser.get_int("LeftPx");
const ck_tile::index_t RightPz = arg_parser.get_int("RightPz");
const ck_tile::index_t RightPy = arg_parser.get_int("RightPy");
const ck_tile::index_t RightPx = arg_parser.get_int("RightPx");
const ck_tile::index_t Zs = (Z - 1) * Dz + 1;
const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
const ck_tile::index_t Xs = (X - 1) * Dx + 1;
const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
printf("Input parameters:\n");
printf("N: %d, D: %d, H: %d, W: %d, C: %d\n", N, D, H, W, C);
printf("Window Z: %d, Y: %d, X: %d, Stride Z: %d, Y: %d, X: %d\n", Z, Y, X, Sz, Sy, Sx);
printf("Output Do: %d, Ho: %d, Wo: %d\n", Do, Ho, Wo);
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
// Shapes / strides / parameters (NDHWC)
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1);
const auto window_spatial_lengths = ck_tile::make_tuple(Z, Y, X);
const auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx);
const auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
const auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
const auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
ck_tile::HostTensor<InDataType> in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1});
ck_tile::HostTensor<OutDataType> out({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
in_buf.ToDevice(in.data());
using ReduceOp = ck_tile::ReduceOp::Max;
using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<128, 128>;
using WarpTile = ck_tile::sequence<32, 128>;
using ThreadTile = ck_tile::sequence<8, 8>;
using Shape = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
OutDataType,
ReduceOp,
false,
false,
Shape>;
using Kernel = ck_tile::PoolKernel<Problem>;
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
input_shape,
output_shape,
input_strides,
output_strides,
window_spatial_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args = Kernel::MakeKernelArgs(host_args);
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
std::cout << "grid size " << kGridSize << std::endl;
// Validate kernel can handle the given configuration
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("ERROR: Kernel arguments are not supported! \n");
}
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
std::size_t num_btype =
sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
in, out_ref, kernel_args, ReduceOp{});
out_buf.FromDevice(out.mData.data());
pass = ck_tile::check_err(out, out_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
}