This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(EXAMPLE_POOL_3D "tile_example_pool3d")
message(DEBUG "adding example ${EXAMPLE_POOL_3D}")
add_executable(${EXAMPLE_POOL_3D} pool3d.cpp)
target_include_directories(${EXAMPLE_POOL_3D} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_POOL_3D} PRIVATE ${EXAMPLE_POOL_COMPILE_OPTIONS})

View File

@@ -0,0 +1,152 @@
# Pooling Operator
This folder contains example for the pooling operator using ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling.
## Tensor Descriptor Transformations
The pooling kernel transforms the input tensor into 2D format suitable for reduction. This section explains the transformation pipeline for both 2D and 3D pooling operations.
### 3D Pooling Transformations
For 3D pooling, the input tensor has shape `(N, D, H, W, C)` where:
- `N`: batch size
- `D`: depth dimension
- `H`: height dimension
- `W`: width dimension
- `C`: channel dimension
The transformations convert this 5D tensor into a 2D tensor where rows represent output positions (M) and columns represent pooling window elements (K).
```mermaid
graph TD
%% Input Tensor: (N, D, H, W, C)
Input["Input Tensor<br/>(N, D, H, W, C)"]
style Input fill:#e1f5fe
%% Pass-through N dimension
PassN["Pass-through N<br/>(batch size)"]
style PassN fill:#f3e5f5
Input --> PassN
%% Pad spatial dimensions
PadD["Pad D<br/>(depth with left/right padding)"]
style PadD fill:#fff9c4
Input --> PadD
PadH["Pad H<br/>(height with left/right padding)"]
style PadH fill:#fff9c4
Input --> PadH
PadW["Pad W<br/>(width with left/right padding)"]
style PadW fill:#fff9c4
Input --> PadW
%% Pass-through C dimension
PassC["Pass-through C<br/>(channels)"]
style PassC fill:#f3e5f5
Input --> PassC
%% Embed sliding windows
EmbedD["Embed D<br/>window(Z) × output_positions(Dₒ)"]
style EmbedD fill:#fff3e0
PadD --> EmbedD
EmbedH["Embed H<br/>window(Y) × output_positions(Hₒ)"]
style EmbedH fill:#fff3e0
PadH --> EmbedH
EmbedW["Embed W<br/>window(X) × output_positions(Wₒ)"]
style EmbedW fill:#fff3e0
PadW --> EmbedW
%% Merge into 2D matrix
MergeM["Merge M<br/>(N, Dₒ, Hₒ, Wₒ, C)<br/>→ output positions"]
style MergeM fill:#e8f5e9
PassN --> MergeM
EmbedD --> MergeM
EmbedH --> MergeM
EmbedW --> MergeM
PassC --> MergeM
MergeK["Merge K<br/>(Z, Y, X)<br/>→ window elements"]
style MergeK fill:#e8f5e9
EmbedD --> MergeK
EmbedH --> MergeK
EmbedW --> MergeK
%% Final padding for block alignment
PadM["Right-pad M<br/>(for block alignment)"]
style PadM fill:#fff9c4
MergeM --> PadM
PadK["Right-pad K<br/>(for block alignment)"]
style PadK fill:#fff9c4
MergeK --> PadK
%% Result
Result["2D Matrix<br/>(M × K)"]
style Result fill:#c8e6c9
PadM --> Result
PadK --> Result
```
**Transformation Steps:**
1. **Padding**: Apply left and right padding to spatial dimensions (D, H, W) to handle boundary conditions
2. **Sliding Windows**: Use embed transforms to create sliding windows across each spatial dimension, expanding each dimension into (window_size, output_positions)
3. **Reshaping**: Merge all dimensions into a 2D matrix where:
- M dimension = N × Dₒ × Hₒ × Wₒ × C (total output positions)
- K dimension = Z × Y × X (elements per pooling window)
4. **Block Alignment**: Apply right padding to ensure M and K dimensions are aligned to block size
### 2D Pooling Transformations
2D pooling follows the same transformation pipeline but operates on 4D tensors with shape `(N, H, W, C)`. The process is identical except:
- Only H and W dimensions are padded and embedded
- K dimension merges only (Y, X) window elements
- M dimension merges (N, Hₒ, Wₒ, C)
### Output Tensor Transformations
The output tensor transformations are simpler:
- Merge all output dimensions (N, Dₒ/Hₒ, Wₒ, C) into a single M dimension
- Apply right padding for block alignment
- The result is a 1D tensor that maps directly to the M dimension of the computation matrix
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# The 3D pooling example
make tile_example_pool3d -j`nproc`
```
This will result in an executable `build/bin/tile_example_pool3d`
## example
```
args:
-N batch size (default:2)
-D depth dimension (default:30)
-H height dimension (default:30)
-W width dimension (default:30)
-C channel dimension (default:32)
-Z pooling window depth (default:2)
-Y pooling window height (default:2)
-X pooling window width (default:2)
-Sz window stride depth (default:2)
-Sy window stride height (default:2)
-Sx window stride width (default:2)
-Dz window dilation depth (default:1)
-Dy window dilation height (default:1)
-Dx window dilation width (default:1)
-LeftPz left padding depth (default:1)
-LeftPy left padding height (default:1)
-LeftPx left padding width (default:1)
-RightPz right padding depth (default:1)
-RightPy right padding height (default:1)
-RightPx right padding width (default:1)
-v 0: No validation, 1: CPU validation (default:1)
-warmup number of iterations before benchmark (default:0)
-repeat number of iterations to benchmark (default:1)
```

View File

@@ -0,0 +1,216 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include <cstring>
// Parse command-line arguments for 3D pooling example
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("N", "2", "N dimension")
.insert("H", "30", "H dimension")
.insert("W", "30", "W dimension")
.insert("C", "32", "C dimension")
.insert("D", "30", "D dimension")
.insert("Z", "2", "Z dimension")
.insert("Y", "2", "Y dimension")
.insert("X", "2", "X dimension")
.insert("Sz", "2", "window stride d")
.insert("Sy", "2", "window stride h")
.insert("Sx", "2", "window stride w")
.insert("Dz", "1", "window dilation d")
.insert("Dy", "1", "window dilation h")
.insert("Dx", "1", "window dilation w")
.insert("LeftPz", "1", "left padding d")
.insert("LeftPy", "1", "left padding h")
.insert("LeftPx", "1", "left padding w")
.insert("RightPz", "1", "right padding d")
.insert("RightPy", "1", "right padding h")
.insert("RightPx", "1", "right padding w")
.insert("v", "1", "cpu validation or not")
.insert("warmup", "20", "cold iter")
.insert("repeat", "100", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t H = arg_parser.get_int("H");
const ck_tile::index_t W = arg_parser.get_int("W");
const ck_tile::index_t C = arg_parser.get_int("C");
const ck_tile::index_t D = arg_parser.get_int("D");
const ck_tile::index_t Z = arg_parser.get_int("Z");
const ck_tile::index_t Y = arg_parser.get_int("Y");
const ck_tile::index_t X = arg_parser.get_int("X");
const ck_tile::index_t Sz = arg_parser.get_int("Sz");
const ck_tile::index_t Sy = arg_parser.get_int("Sy");
const ck_tile::index_t Sx = arg_parser.get_int("Sx");
const ck_tile::index_t Dz = arg_parser.get_int("Dz");
const ck_tile::index_t Dy = arg_parser.get_int("Dy");
const ck_tile::index_t Dx = arg_parser.get_int("Dx");
const ck_tile::index_t LeftPz = arg_parser.get_int("LeftPz");
const ck_tile::index_t LeftPy = arg_parser.get_int("LeftPy");
const ck_tile::index_t LeftPx = arg_parser.get_int("LeftPx");
const ck_tile::index_t RightPz = arg_parser.get_int("RightPz");
const ck_tile::index_t RightPy = arg_parser.get_int("RightPy");
const ck_tile::index_t RightPx = arg_parser.get_int("RightPx");
const ck_tile::index_t Zs = (Z - 1) * Dz + 1;
const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
const ck_tile::index_t Xs = (X - 1) * Dx + 1;
const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
printf("Input parameters:\n");
printf("N: %d, D: %d, H: %d, W: %d, C: %d\n", N, D, H, W, C);
printf("Window Z: %d, Y: %d, X: %d, Stride Z: %d, Y: %d, X: %d\n", Z, Y, X, Sz, Sy, Sx);
printf("Output Do: %d, Ho: %d, Wo: %d\n", Do, Ho, Wo);
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
constexpr bool OutputIndex = true;
constexpr bool PropagateNan = false;
// Shapes / strides / parameters (NDHWC)
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1);
const auto window_spatial_lengths = ck_tile::make_tuple(Z, Y, X);
const auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx);
const auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
const auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
const auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
ck_tile::HostTensor<InDataType> in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1});
ck_tile::HostTensor<OutDataType> out({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
in_buf.ToDevice(in.data());
using ReduceOp = ck_tile::ReduceOp::Max;
using BlockWarps = ck_tile::sequence<1, 1>;
using BlockTile = ck_tile::sequence<128, 1>;
using WarpTile = ck_tile::sequence<128, 1>;
using ThreadTile = ck_tile::sequence<2, 1>;
using Shape = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
ReduceOp,
OutputIndex,
PropagateNan,
Shape>;
using Kernel = ck_tile::PoolKernel<Problem>;
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
input_shape,
output_shape,
input_strides,
output_strides,
window_spatial_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args = Kernel::MakeKernelArgs(host_args);
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
std::cout << "grid size " << kGridSize << std::endl;
// Validate kernel can handle the given configuration
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("ERROR: Kernel arguments are not supported! \n");
}
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
std::size_t num_btype =
sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
out_buf.FromDevice(out.mData.data());
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOp,
decltype(input_shape),
decltype(window_spatial_lengths),
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
if constexpr(OutputIndex)
{
out_index_buf.FromDevice(out_index.mData.data());
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
}
else
{
pass = ck_tile::check_err(out, out_ref);
}
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
}