[CK_TILE] Add pooling in tile_engine (#4469)

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->
Add support for pooling operations (codegen, kernel instances, and tests) to the CK tile engine.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
aledudek
2026-04-01 09:31:46 +02:00
committed by GitHub
parent 9b8b2456b4
commit 357a140e7b
25 changed files with 3258 additions and 19 deletions

View File

@@ -103,6 +103,42 @@ struct Max
}
};
struct Min
{
    // Identity element for a min-reduction: the largest representable value,
    // so any real input compares smaller on the first accumulation step.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return numeric<T>::max();
    } // (fixed: removed stray ';' after the function body)

    // Running-min accumulation: y is the accumulator, x the incoming element.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return min(y, x);
    }

    // Overload with changed flag for index tracking: sets `changed` only when
    // x strictly improves on the accumulator y, so ties keep the earlier
    // index. NOTE(review): a NaN input never satisfies x < y, so `changed`
    // stays false for NaN -- confirm this matches the intended
    // propagate_nan semantics.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
    {
        T new_min = min(y, x);
        if(x < y)
        {
            changed = true;
        }
        return new_min;
    }
};
struct AbsMax
{
template <

View File

@@ -70,3 +70,4 @@ add_subdirectory(gemm_tile_engine)
# Per-op test suites; each subdirectory registers its own test targets.
add_subdirectory(pooling)
add_subdirectory(grouped_conv)
add_subdirectory(gemm_streamk_tile_engine)
# Unit tests for tile_engine generated pooling kernels.
add_subdirectory(pooling_tile_engine)

View File

@@ -234,7 +234,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -245,7 +245,7 @@ endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()

View File

@@ -232,7 +232,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -243,7 +243,7 @@ endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()

View File

@@ -0,0 +1,341 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ============================================================================
# Pooling Tile Engine Unit Tests
#
# This CMake file creates unit tests for tile_engine generated pooling kernels.
# Each kernel configuration gets its own test executable.
# ============================================================================
# Locate tile_engine pooling scripts directory (the codegen lives with the op).
set(TILE_ENGINE_POOLING_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/pooling")
# Quote the expansion so a path containing spaces (or an empty value) still
# produces a well-formed if() argument.
if(NOT EXISTS "${TILE_ENGINE_POOLING_DIR}")
    message(WARNING "Tile engine pooling directory not found: ${TILE_ENGINE_POOLING_DIR}")
    return()
endif()
# ============================================================================
# create_individual_pool_test_target
#
# Creates a single test executable for a specific pooling kernel configuration.
#
# Parameters:
#   datatype     - Data type (fp16, fp32, bf16)
#   config_name  - Configuration file name without .json extension
#   kernel_name  - Kernel name from pool_kernel_list.txt (matches the generated header)
#   trait        - Kernel trait combination string
#   tile_config  - Tile configuration parameters
#   config_json  - Full path to JSON configuration file (currently unused)
# ============================================================================
function(create_individual_pool_test_target datatype config_name kernel_name trait tile_config config_json)
    # One GTest executable per kernel configuration so failures are isolated
    # and individual kernels can be built/run on their own.
    set(target_name "test_pooling_tile_engine_${datatype}_${config_name}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${config_name}")

    # Generated header path (already created during cmake configuration).
    # Use kernel_name from pool_kernel_list.txt to match the filename generated
    # by pooling_instance_builder.py.
    set(test_header "${working_path}/pooling_single_${kernel_name}.hpp")

    # Determine pooling dimension from the trait string
    # (format: reduce_op_output_index_propagate_nan_pooling_dim).
    # The pooling_dim is the last field: "2d" or "3d".
    string(REGEX MATCH "[23]d$" kernel_pooling_dim "${trait}")
    # Quote the expansion: REGEX MATCH leaves the variable empty on no match,
    # and an unquoted empty expansion would malform the if() arguments.
    if("${kernel_pooling_dim}" STREQUAL "3d")
        set(test_params_header "${working_path}/test_params_3d.hpp")
        set(pooling_dim_value 3)
    else()
        set(test_params_header "${working_path}/test_params_2d.hpp")
        set(pooling_dim_value 2)
    endif()

    # Verify the generated kernel header exists before creating a target.
    if(NOT EXISTS "${test_header}")
        message(WARNING "Generated header not found: ${test_header}")
        return()
    endif()
    # Verify the test parameters header exists.
    if(NOT EXISTS "${test_params_header}")
        message(WARNING "Test parameters header not found: ${test_params_header}")
        return()
    endif()

    # Create GTest executable for this kernel configuration.
    add_gtest_executable(${target_name}
        ${CMAKE_CURRENT_SOURCE_DIR}/test_pooling_simple.cpp
    )
    # Configure GPU architectures for HIP compilation (target property, not a
    # global flag, so only this executable is affected).
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${POOLING_TEST_GPU_TARGETS})

    # Preprocessor macros selecting the generated kernel header, the matching
    # test-parameter header, and the pooling dimension (2 or 3) at compile time.
    target_compile_definitions(${target_name} PRIVATE
        POOLING_SINGLE_INSTANCE_HPP="${test_header}"
        POOLING_TEST_PARAMS_HPP="${test_params_header}"
        POOLING_DIM_VALUE=${pooling_dim_value}
    )
    # Include directories for headers and dependencies.
    target_include_directories(${target_name} PRIVATE
        ${PROJECT_SOURCE_DIR}/include
        ${PROJECT_BINARY_DIR}/include
        ${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
        ${GTEST_INCLUDE_DIRS}
    )
    # Compiler options matching tile_engine requirements; -include force-feeds
    # the generated kernel header into the shared test translation unit.
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${test_header}
    )
    # Add FP8 format definitions for proper data type interpretation.
    if(CK_USE_OCP_FP8)
        target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
    endif()
    message(STATUS " Created test target: ${target_name}")
endfunction()
# ============================================================================
# build_pool_test_targets
#
# Builds all test targets for a specific datatype/config combination.
# Uses tile_engine's two-step process: list kernels, then generate tests.
#
# Parameters:
# datatype - Data type (fp16, fp32, bf16)
# config_name - Configuration file name without .json extension
# ============================================================================
function(build_pool_test_targets datatype config_name)
    # Builds all test targets for one datatype/config combination using
    # tile_engine's two-step process: list kernels, then generate headers
    # and per-kernel test targets.
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${config_name}")

    # Locate and validate configuration file.
    set(config_filename "${config_name}.json")
    set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
    if(NOT EXISTS "${json_blob}")
        message(WARNING "Test config file not found: ${json_blob}")
        return()
    endif()

    # Prepare build directory for this configuration.
    file(MAKE_DIRECTORY "${working_path}")

    # STEP 1: Discovery phase - list all valid kernel configurations.
    # NOTE(review): execute_process runs at configure time, so new kernels
    # only appear after a CMake re-configure.
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_POOLING_DIR}/pooling_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --config_json ${json_blob}
                --list_kernels
        WORKING_DIRECTORY ${TILE_ENGINE_POOLING_DIR}
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )
    if(NOT ret EQUAL 0)
        message(WARNING "Failed to list pooling kernels for ${datatype}_${config_name}: ${list_error}")
        return()
    endif()

    # Verify kernel list file was generated.
    if(NOT EXISTS "${working_path}/pool_kernel_list.txt")
        message(STATUS "No pooling kernels found for ${datatype}_${config_name}")
        return()
    endif()
    message(STATUS "Building pooling tests for ${datatype}_${config_name}")

    # STEP 2a: Extract test parameters from config for BOTH 2D and 3D dimensions.
    # Each kernel's pooling_dim is embedded in its trait string, so we generate
    # separate test_params headers and select the right one per kernel target.
    set(test_params_file_2d "${working_path}/test_params_2d.hpp")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
                --config_file ${json_blob}
                --output_file ${test_params_file_2d}
                --pooling_dim 2d
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE extract_ret_2d
        OUTPUT_VARIABLE extract_output_2d
        ERROR_VARIABLE extract_error_2d
    )
    if(NOT extract_ret_2d EQUAL 0)
        message(WARNING "Failed to extract 2D test parameters for pooling ${datatype}: ${extract_error_2d}")
        return()
    endif()

    set(test_params_file_3d "${working_path}/test_params_3d.hpp")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
                --config_file ${json_blob}
                --output_file ${test_params_file_3d}
                --pooling_dim 3d
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE extract_ret_3d
        OUTPUT_VARIABLE extract_output_3d
        ERROR_VARIABLE extract_error_3d
    )
    if(NOT extract_ret_3d EQUAL 0)
        message(WARNING "Failed to extract 3D test parameters for pooling ${datatype}: ${extract_error_3d}")
        return()
    endif()

    # STEP 2c: Header generation phase - generate headers using --gen_single.
    message(STATUS " Generating pooling headers using --gen_single...")
    # Read the kernel list once; it is reused for the target-creation pass
    # below (the original re-read the same file a second time).
    file(STRINGS "${working_path}/pool_kernel_list.txt" kernel_lines)
    set(gen_count 0)
    foreach(line IN LISTS kernel_lines)
        # Parse kernel specification format: kernel_name|tile_config|trait_combo
        string(REPLACE "|" ";" parts "${line}")
        list(LENGTH parts parts_len)
        if(parts_len EQUAL 3)
            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)
            # Generate the kernel header for this single configuration.
            execute_process(
                COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_POOLING_DIR}/pooling_instance_builder.py
                        --working_path ${working_path}
                        --datatype ${datatype}
                        --config_json ${json_blob}
                        --gen_single
                        --kernel_name "${kernel_name}"
                        --tile_config "${tile_config}"
                        --trait_combo "${trait_combo}"
                WORKING_DIRECTORY ${TILE_ENGINE_POOLING_DIR}
                RESULT_VARIABLE gen_ret
                OUTPUT_VARIABLE gen_output
                ERROR_VARIABLE gen_error
            )
            if(NOT gen_ret EQUAL 0)
                message(WARNING "Failed to generate pooling header for ${kernel_name}: ${gen_error}")
            else()
                math(EXPR gen_count "${gen_count} + 1")
            endif()
        endif()
    endforeach()
    message(STATUS " Generated ${gen_count} pooling headers for ${datatype}")

    # STEP 3: Target creation phase - create test targets (reuses kernel_lines).
    message(STATUS " Creating pooling test targets...")
    set(test_count 0)
    foreach(line IN LISTS kernel_lines)
        string(REPLACE "|" ";" parts "${line}")
        list(LENGTH parts parts_len)
        if(parts_len EQUAL 3)
            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)
            create_individual_pool_test_target("${datatype}" "${config_name}" "${kernel_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
            math(EXPR test_count "${test_count} + 1")
        endif()
    endforeach()
    message(STATUS " Created ${test_count} pooling test targets for ${datatype}")
endfunction()
# ============================================================================
# MAIN EXECUTION - Test Target Generation
# ============================================================================
message(STATUS "=== Starting Pooling Tile Engine Test Configuration ===")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# GPU architecture filtering - only build tests for supported architectures.
set(POOLING_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
    if(target IN_LIST DESIRED_TARGETS)
        list(APPEND POOLING_TEST_GPU_TARGETS ${target})
        message(STATUS " Adding GPU target for pooling tests: ${target}")
    endif()
endforeach()

# Early exit if no compatible GPU architectures are available.
if(NOT POOLING_TEST_GPU_TARGETS)
    message(WARNING "Skipping Pooling Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()
message(STATUS "Building Pooling tile engine tests for GPU targets: ${POOLING_TEST_GPU_TARGETS}")

# Enable parallel compilation optimizations.
# NOTE(review): JOB_POOLS is a GLOBAL property; setting it here overwrites any
# pools defined elsewhere in the build -- confirm no other directory sets it.
set_property(GLOBAL PROPERTY JOB_POOLS
    compile_heavy=4
    compile_normal=16
)

# Enable compiler cache if available and explicitly requested.
# NOTE(review): CMAKE_CXX_COMPILER_LAUNCHER set here applies to every target
# configured in and below this directory, not just the pooling tests.
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
if(ENABLE_CCACHE_TESTS)
    find_program(CCACHE_PROGRAM ccache)
    if(CCACHE_PROGRAM)
        set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
        message(STATUS "Using ccache for faster test compilation")
    else()
        message(WARNING "ccache requested but not found")
    endif()
else()
    message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
endif()

# ============================================================================
# Test Configuration Matrix
# ============================================================================
set(TEST_DATATYPES "fp16;fp32")

# ============================================================================
# Test Target Generation
# ============================================================================
# 1. SIMPLE TEST: Basic functionality validation (always built).
set(SIMPLE_TEST_CONFIG "simple_test_config")
set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json")
if(EXISTS "${SIMPLE_TEST_CONFIG_FILE}")
    message(STATUS "Processing pooling simple test config: ${SIMPLE_TEST_CONFIG}")
    foreach(datatype IN LISTS TEST_DATATYPES)
        build_pool_test_targets("${datatype}" "${SIMPLE_TEST_CONFIG}")
    endforeach()
else()
    message(WARNING "Pooling simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}")
endif()

# 2. COVERAGE LEVEL: Quick or comprehensive testing.
#    Quick: ~2 kernels (1 tile config x 1 trait combo x fp16/fp32) from simple config only.
#    Comprehensive: ~200+ kernels with extensive tile sizes, warp configurations,
#    and all trait combinations.
set(POOLING_COVERAGE_LEVEL "quick" CACHE STRING "Pooling coverage level: quick or comprehensive")
set_property(CACHE POOLING_COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive")
# Quote the cache variable so the comparison stays well-formed if a user sets
# it to an empty string (CMP0054 semantics).
if("${POOLING_COVERAGE_LEVEL}" STREQUAL "comprehensive")
    set(COMPREHENSIVE_CONFIG "comprehensive_coverage_config")
    set(COMPREHENSIVE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COMPREHENSIVE_CONFIG}.json")
    if(EXISTS "${COMPREHENSIVE_CONFIG_FILE}")
        message(STATUS "Processing pooling comprehensive coverage config: ${COMPREHENSIVE_CONFIG}")
        foreach(datatype IN LISTS TEST_DATATYPES)
            build_pool_test_targets("${datatype}" "${COMPREHENSIVE_CONFIG}")
        endforeach()
    else()
        message(WARNING "Pooling comprehensive config file not found: ${COMPREHENSIVE_CONFIG_FILE}")
    endif()
elseif(NOT "${POOLING_COVERAGE_LEVEL}" STREQUAL "quick")
    message(FATAL_ERROR "Invalid POOLING_COVERAGE_LEVEL: ${POOLING_COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'")
endif()

message(STATUS "Pooling tile engine tests configured:")
message(STATUS " - Simple test: fp16/fp32 (always)")
message(STATUS " - Coverage level: ${POOLING_COVERAGE_LEVEL}")
message(STATUS " Use -DPOOLING_COVERAGE_LEVEL=comprehensive for extensive testing")

View File

@@ -0,0 +1,87 @@
# Pooling Tile Engine Tests
Unit tests for pooling kernels generated by the tile_engine pooling codegen system.
## Overview
These tests validate pooling kernels that are generated at CMake configuration time
by `pooling_instance_builder.py`. Each kernel configuration (tile shape + traits)
gets its own GTest executable that verifies correctness against a CPU reference
implementation.
## Architecture
```
test/ck_tile/pooling_tile_engine/
├── CMakeLists.txt # Build infrastructure
├── configs/
│ └── simple_test_config.json # Test configuration with problem sizes
├── extract_test_params.py # Extracts problem sizes to C++ header
├── test_pooling_simple.cpp # GTest driver (parameterized)
└── README.md # This file
```
### Build Flow
1. **CMake configuration**: `CMakeLists.txt` invokes `pooling_instance_builder.py --list_kernels`
to discover valid kernel configurations from the JSON config.
2. **Parameter extraction**: `extract_test_params.py` generates `test_params_2d.hpp` and
`test_params_3d.hpp` with problem sizes from the JSON config.
3. **Header generation**: For each kernel, `pooling_instance_builder.py --gen_single`
generates a C++ header defining `SelectedKernel` with the specific tile configuration.
4. **Compilation**: Each kernel gets a separate test executable compiled with the
generated header via `-include`.
5. **Execution**: GTest runs each problem size as a separate test case, comparing
device results against the CPU reference.
## Configuration
### `simple_test_config.json`
Defines:
- **tile_config**: Block/warp/thread tile dimensions for PoolShape
- **trait_config**: Reduce op (max/avg), output_index, propagate_nan, pooling_dim (2d/3d)
- **test_params**: Problem sizes (N, H, W, C, window, stride, dilation, padding)
### Supported configurations
- **Data types**: fp16, fp32
- **Reduce operations**: max (with index output)
- **Pooling dimensions**: 2D (NHWC), 3D (NDHWC)
- **GPU targets**: gfx90a, gfx942, gfx950, gfx1201
## Building
```bash
# From the build directory:
cmake --build . --target test_pooling_tile_engine_fp16_simple_test_config_max_true_false_2d_128x1_1x1_128x1_2x1
# Or build all pooling tests:
cmake --build . --target tests
```
## Running
```bash
# Run a specific test:
./test_pooling_tile_engine_fp16_simple_test_config_max_true_false_2d_128x1_1x1_128x1_2x1
# Run with GTest filters:
./test_pooling_tile_engine_fp16_simple_test_config_max_true_false_2d_128x1_1x1_128x1_2x1 --gtest_filter="*BasicFunctionality*"
```
## Relationship to tile_engine
The tile_engine pooling op lives at `tile_engine/ops/pooling/` and provides:
- `pooling_instance_builder.py` - Codegen for kernel headers
- `pooling_validation_utils.py` - Configuration validation
- `pooling_common.hpp` - Shared trait definitions
- `pooling_benchmark.hpp` - Problem/metric definitions
- `pooling_benchmark_single.cpp` - Single-kernel benchmark entry point
The underlying ck_tile pooling kernel lives at `include/ck_tile/ops/pooling/` and provides:
- `PoolKernel` - GPU kernel implementation
- `PoolProblem` - Problem parameterization
- `PoolShape` - Tile shape specification
- `PoolDefaultPolicy` - Tile distribution and reduction policies

View File

@@ -0,0 +1,165 @@
{
"problem": {
"description": "Comprehensive pooling coverage testing - multiple block sizes (64-512), warp configurations, thread tile sizes, and all trait combinations (max/avg, index, NaN propagation). Approximately 200+ kernels."
},
"test_params": {
"problem_sizes_2d": [
{
"_comment": "Basic: small tensor, 2x2 window, stride 2, no padding",
"N": 1, "H": 8, "W": 8, "C": 32,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Padded 3x3: moderate tensor with symmetric padding, stride 1 (overlapping)",
"N": 1, "H": 16, "W": 16, "C": 64,
"Y": 3, "X": 3,
"stride_h": 1, "stride_w": 1,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 1, "pad_h_right": 1,
"pad_w_left": 1, "pad_w_right": 1
},
{
"_comment": "Large channels: stress-test the C dimension",
"N": 1, "H": 16, "W": 16, "C": 256,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Large batch: multi-batch correctness",
"N": 4, "H": 16, "W": 16, "C": 32,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Non-square spatial: rectangular H != W",
"N": 2, "H": 32, "W": 16, "C": 64,
"Y": 3, "X": 3,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 1, "pad_h_right": 1,
"pad_w_left": 1, "pad_w_right": 1
},
{
"_comment": "Large window 5x5: bigger receptive field",
"N": 1, "H": 32, "W": 32, "C": 32,
"Y": 5, "X": 5,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 2, "pad_h_right": 2,
"pad_w_left": 2, "pad_w_right": 2
},
{
"_comment": "Large window 7x7: global-style pooling",
"N": 1, "H": 14, "W": 14, "C": 128,
"Y": 7, "X": 7,
"stride_h": 1, "stride_w": 1,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 3, "pad_h_right": 3,
"pad_w_left": 3, "pad_w_right": 3
},
{
"_comment": "Dilated: dilation_h=2, dilation_w=2 with 3x3 window",
"N": 1, "H": 32, "W": 32, "C": 64,
"Y": 3, "X": 3,
"stride_h": 1, "stride_w": 1,
"dilation_h": 2, "dilation_w": 2,
"pad_h_left": 2, "pad_h_right": 2,
"pad_w_left": 2, "pad_w_right": 2
},
{
"_comment": "Asymmetric padding: different left/right padding",
"N": 2, "H": 16, "W": 16, "C": 32,
"Y": 3, "X": 3,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 1,
"pad_w_left": 0, "pad_w_right": 1
},
{
"_comment": "Large spatial: bigger feature maps",
"N": 1, "H": 64, "W": 64, "C": 64,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Non-square window: Y != X",
"N": 1, "H": 32, "W": 32, "C": 32,
"Y": 3, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 1, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Stride-1 overlap: overlapping 2x2 windows",
"N": 2, "H": 16, "W": 16, "C": 64,
"Y": 2, "X": 2,
"stride_h": 1, "stride_w": 1,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
}
],
"problem_sizes_3d": [
{
"_comment": "Basic 3D: small volume, 2x2x2 window",
"N": 1, "D": 4, "H": 4, "W": 4, "C": 32,
"Z": 2, "Y": 2, "X": 2,
"stride_d": 2, "stride_h": 2, "stride_w": 2,
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
"pad_d_left": 0, "pad_d_right": 0,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"_comment": "Padded 3D: with symmetric padding",
"N": 1, "D": 8, "H": 8, "W": 8, "C": 32,
"Z": 3, "Y": 3, "X": 3,
"stride_d": 2, "stride_h": 2, "stride_w": 2,
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
"pad_d_left": 1, "pad_d_right": 1,
"pad_h_left": 1, "pad_h_right": 1,
"pad_w_left": 1, "pad_w_right": 1
},
{
"_comment": "Multi-batch 3D: larger batch and channels",
"N": 2, "D": 8, "H": 8, "W": 8, "C": 64,
"Z": 2, "Y": 2, "X": 2,
"stride_d": 2, "stride_h": 2, "stride_w": 2,
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
"pad_d_left": 0, "pad_d_right": 0,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
}
]
},
"tile_config": {
"block_m": {"values": [64, 128, 256, 512]},
"block_n": {"values": [1]},
"warp_m": {"values": [1, 2, 4]},
"warp_n": {"values": [1]},
"warp_tile_m": {"values": [64, 128, 256]},
"warp_tile_n": {"values": [1]},
"thread_tile_m": {"values": [1, 2, 4]},
"thread_tile_n": {"values": [1]}
},
"trait_config": {
"reduce_op": {"values": ["max", "avg"]},
"output_index": {"values": [true, false]},
"propagate_nan": {"values": [true, false]},
"pooling_dim": {"values": ["2d", "3d"]}
}
}

View File

@@ -0,0 +1,60 @@
{
"problem": {
"description": "Basic pooling functionality validation with moderate problem sizes"
},
"test_params": {
"problem_sizes_2d": [
{
"N": 1, "H": 8, "W": 8, "C": 32,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
},
{
"N": 2, "H": 16, "W": 16, "C": 32,
"Y": 3, "X": 3,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 1, "pad_h_right": 1,
"pad_w_left": 1, "pad_w_right": 1
},
{
"N": 1, "H": 32, "W": 32, "C": 64,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
}
],
"problem_sizes_3d": [
{
"N": 1, "D": 4, "H": 4, "W": 4, "C": 32,
"Z": 2, "Y": 2, "X": 2,
"stride_d": 2, "stride_h": 2, "stride_w": 2,
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
"pad_d_left": 0, "pad_d_right": 0,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0
}
]
},
"tile_config": {
"block_m": {"values": [128]},
"block_n": {"values": [1]},
"warp_m": {"values": [1]},
"warp_n": {"values": [1]},
"warp_tile_m": {"values": [128]},
"warp_tile_n": {"values": [1]},
"thread_tile_m": {"values": [2]},
"thread_tile_n": {"values": [1]}
},
"trait_config": {
"reduce_op": {"values": ["max"]},
"output_index": {"values": [true]},
"propagate_nan": {"values": [false]},
"pooling_dim": {"values": ["2d"]}
}
}

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Extract pooling test parameters from config JSON and write to C++ header.
Generates test_params.hpp with problem sizes for parameterized GTest.
"""
import json
import argparse
import os
from pathlib import Path
def extract_test_params(config_file, output_file, pooling_dim="2d"):
    """Extract test parameters from a config JSON and write a C++ header.

    Reads ``config_file``, selects the 2D or 3D problem-size list (falling
    back to built-in defaults when the config omits the corresponding key),
    and writes ``output_file`` defining a ``CONFIG_TEST_PARAMS`` vector of
    aggregate initializers. Field order must match the PoolTestParams2D/3D
    structs in test_pooling_simple.cpp.

    Args:
        config_file: Path to the input configuration JSON file.
        output_file: Path of the header to write; parent dirs are created.
        pooling_dim: "2d" or "3d"; any value other than "2d" is treated as
            3d, mirroring the CLI's choices.
    """
    with open(config_file, "r") as f:
        config = json.load(f)

    test_params = _select_problem_sizes(config, pooling_dim)

    # Ensure the destination directory exists (CMake passes nested paths).
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    if pooling_dim == "2d":
        struct_name, field_order = "PoolTestParams2D", _FIELDS_2D
    else:
        struct_name, field_order = "PoolTestParams3D", _FIELDS_3D

    with open(output_file, "w") as f:
        f.write("// Generated test parameters for pooling tile_engine tests\n")
        f.write("// This file is auto-generated during CMake configuration\n\n")
        f.write(f"static const std::vector<{struct_name}> CONFIG_TEST_PARAMS = {{\n")
        rows = [
            "    {" + ", ".join(str(params[key]) for key in field_order) + "}"
            for params in test_params
        ]
        # Guard against an explicitly empty list: the original emitted a
        # trailing comma logic per-row; joining avoids it, and skipping the
        # write avoids a stray blank line when there are no rows.
        if rows:
            f.write(",\n".join(rows) + "\n")
        f.write("};\n")
    print(
        f"Extracted {len(test_params)} {pooling_dim} test parameters from {config_file} -> {output_file}"
    )


# Aggregate-initializer field order; must match the struct layouts declared in
# test_pooling_simple.cpp.
_FIELDS_2D = (
    "N", "H", "W", "C",
    "Y", "X",
    "stride_h", "stride_w",
    "dilation_h", "dilation_w",
    "pad_h_left", "pad_h_right",
    "pad_w_left", "pad_w_right",
)
_FIELDS_3D = (
    "N", "D", "H", "W", "C",
    "Z", "Y", "X",
    "stride_d", "stride_h", "stride_w",
    "dilation_d", "dilation_h", "dilation_w",
    "pad_d_left", "pad_d_right",
    "pad_h_left", "pad_h_right",
    "pad_w_left", "pad_w_right",
)

# Fallback problem sizes used when the config JSON omits the relevant key.
_DEFAULT_PROBLEM_SIZES = {
    "2d": [
        {
            "N": 1, "H": 8, "W": 8, "C": 32,
            "Y": 2, "X": 2,
            "stride_h": 2, "stride_w": 2,
            "dilation_h": 1, "dilation_w": 1,
            "pad_h_left": 0, "pad_h_right": 0,
            "pad_w_left": 0, "pad_w_right": 0,
        },
        {
            "N": 2, "H": 16, "W": 16, "C": 32,
            "Y": 3, "X": 3,
            "stride_h": 2, "stride_w": 2,
            "dilation_h": 1, "dilation_w": 1,
            "pad_h_left": 1, "pad_h_right": 1,
            "pad_w_left": 1, "pad_w_right": 1,
        },
    ],
    "3d": [
        {
            "N": 1, "D": 4, "H": 4, "W": 4, "C": 32,
            "Z": 2, "Y": 2, "X": 2,
            "stride_d": 2, "stride_h": 2, "stride_w": 2,
            "dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
            "pad_d_left": 0, "pad_d_right": 0,
            "pad_h_left": 0, "pad_h_right": 0,
            "pad_w_left": 0, "pad_w_right": 0,
        },
    ],
}


def _select_problem_sizes(config, pooling_dim):
    """Return the configured problem-size list for pooling_dim, or defaults.

    Presence of the key wins even when its list is empty, matching the
    original behavior (no silent fallback for an explicitly empty list).
    """
    dim = "2d" if pooling_dim == "2d" else "3d"
    key = f"problem_sizes_{dim}"
    if "test_params" in config and key in config["test_params"]:
        return config["test_params"][key]
    return _DEFAULT_PROBLEM_SIZES[dim]
def main():
    """CLI entry point: parse arguments and emit the test-parameter header.

    Returns 0 on success, 1 when the config file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Extract pooling test parameters from config JSON"
    )
    parser.add_argument("--config_file", required=True, help="Input config JSON file")
    parser.add_argument(
        "--output_file", required=True, help="Output test parameters file"
    )
    parser.add_argument(
        "--pooling_dim",
        default="2d",
        choices=["2d", "3d"],
        help="Pooling dimension (2d or 3d)",
    )
    args = parser.parse_args()

    # Guard clause inverted relative to the happy path: bail out with a
    # non-zero exit code when the config is missing.
    if os.path.exists(args.config_file):
        extract_test_params(args.config_file, args.output_file, args.pooling_dim)
        return 0
    print(f"Error: Config file not found: {args.config_file}")
    return 1


if __name__ == "__main__":
    exit(main())

View File

@@ -0,0 +1,435 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file test_pooling_simple.cpp
* @brief Unit tests for pooling kernels generated by pooling_instance_builder
*
* This test includes kernels generated during CMake configuration by
* pooling_instance_builder.py and tests them with problem sizes extracted
* from the corresponding JSON configuration files.
*/
#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "tile_engine/ops/pooling/pooling_common.hpp"
// The kernel header is included via compile command line with -include flag
// It defines: SelectedKernel, KERNEL_NAME, InDataType, OutDataType,
// ComputeDataType, IndexDataType, ReduceOpType,
// TensorShape, WindowShape, POOLING_DIM
// ============================================================================
// Test parameter structures
// ============================================================================
/// @brief Test parameters for 2D pooling (NHWC layout).
/// Plain aggregate so it can be brace-initialized from the generated
/// test_params header; member order must match the initializer order emitted
/// by extract_test_params.py.
struct PoolTestParams2D
{
    int N, H, W, C;              // Input dimensions (NHWC)
    int Y, X;                    // Window size
    int stride_h, stride_w;      // Strides
    int dilation_h, dilation_w;  // Dilations
    int pad_h_left, pad_h_right; // Height padding
    int pad_w_left, pad_w_right; // Width padding
};
/// @brief Test parameters for 3D pooling (NDHWC layout).
/// Plain aggregate so it can be brace-initialized from the generated
/// test_params header; member order must match the initializer order emitted
/// by extract_test_params.py.
struct PoolTestParams3D
{
    int N, D, H, W, C;                      // Input dimensions (NDHWC)
    int Z, Y, X;                            // Window size
    int stride_d, stride_h, stride_w;       // Strides
    int dilation_d, dilation_h, dilation_w; // Dilations
    int pad_d_left, pad_d_right;            // Depth padding
    int pad_h_left, pad_h_right;            // Height padding
    int pad_w_left, pad_w_right;            // Width padding
};
// Include config-specific test parameters (after parameter structs are defined)
#ifdef POOLING_TEST_PARAMS_HPP
#include POOLING_TEST_PARAMS_HPP
#endif
// POOLING_DIM_VALUE is set by CMake as a compile definition:
// 2 for 2D pooling kernels, 3 for 3D pooling kernels.
// This selects the appropriate test class and parameterization at compile time.
#if POOLING_DIM_VALUE == 2
// ============================================================================
// 2D Pooling Tests
// ============================================================================
/// @brief Parameterized fixture for 2D pooling kernel verification.
///
/// Unpacks the PoolTestParams2D under test into named members and derives the
/// output extents (Ho_, Wo_) from the standard pooling formula with dilation
/// and asymmetric padding.
class PoolingTileEngineTest2D : public ::testing::TestWithParam<PoolTestParams2D>
{
    protected:
    void SetUp() override
    {
        const PoolTestParams2D& p = GetParam();

        N_ = p.N;
        H_ = p.H;
        W_ = p.W;
        C_ = p.C;
        Y_ = p.Y;
        X_ = p.X;
        stride_h_   = p.stride_h;
        stride_w_   = p.stride_w;
        dilation_h_ = p.dilation_h;
        dilation_w_ = p.dilation_w;
        pad_h_left_  = p.pad_h_left;
        pad_h_right_ = p.pad_h_right;
        pad_w_left_  = p.pad_w_left;
        pad_w_right_ = p.pad_w_right;

        // Output extent: (len + pads - effective_window) / stride + 1, where the
        // effective (dilated) window spans (win - 1) * dilation + 1 elements.
        const auto pooled_extent =
            [](int len, int win, int stride, int dilation, int pad_lo, int pad_hi) -> int {
            const int eff_win = (win - 1) * dilation + 1;
            return (len + pad_lo + pad_hi - eff_win) / stride + 1;
        };
        Ho_ = pooled_extent(H_, Y_, stride_h_, dilation_h_, pad_h_left_, pad_h_right_);
        Wo_ = pooled_extent(W_, X_, stride_w_, dilation_w_, pad_w_left_, pad_w_right_);
    }

    int N_, H_, W_, C_;              // Input dimensions (NHWC)
    int Y_, X_;                      // Window size
    int stride_h_, stride_w_;        // Strides
    int dilation_h_, dilation_w_;    // Dilations
    int pad_h_left_, pad_h_right_;   // Height padding
    int pad_w_left_, pad_w_right_;   // Width padding
    int Ho_, Wo_;                    // Derived output dimensions
};
/// @brief End-to-end verification of one generated 2D pooling kernel.
///
/// Runs SelectedKernel (provided by the -include'd generated header) on
/// device-resident NHWC tensors, computes a host-side reference with
/// ck_tile::reference_pool2d, and compares values — plus argmax indices when
/// the kernel was built with kOutputIndex.
TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
{
    // Create host tensors
    ck_tile::HostTensor<InDataType> h_in({N_, H_, W_, C_});
    ck_tile::HostTensor<OutDataType> h_out({N_, Ho_, Wo_, C_});
    ck_tile::HostTensor<OutDataType> h_out_ref({N_, Ho_, Wo_, C_});
    ck_tile::HostTensor<IndexDataType> h_out_index({N_, Ho_, Wo_, C_});
    ck_tile::HostTensor<IndexDataType> h_out_ref_index({N_, Ho_, Wo_, C_});
    // Initialize input with random data in [-5, 5]
    ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
    h_out.SetZero();
    h_out_ref.SetZero();
    // Device memory
    ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
    d_in.ToDevice(h_in.data());
    d_out.SetZero();
    d_out_index.SetZero();
    // Build shapes and strides (NHWC layout, innermost dim C is contiguous)
    auto input_shape = ck_tile::make_tuple(N_, H_, W_, C_);
    auto output_shape = ck_tile::make_tuple(N_, Ho_, Wo_, C_);
    auto input_strides = ck_tile::make_tuple(H_ * W_ * C_, W_ * C_, C_, 1);
    auto output_strides = ck_tile::make_tuple(Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
    auto window_lengths = ck_tile::make_tuple(Y_, X_);
    auto window_strides = ck_tile::make_tuple(stride_h_, stride_w_);
    auto window_dilations = ck_tile::make_tuple(dilation_h_, dilation_w_);
    auto input_left_pads = ck_tile::make_tuple(pad_h_left_, pad_w_left_);
    auto input_right_pads = ck_tile::make_tuple(pad_h_right_, pad_w_right_);
    // Build host args for the generated kernel (device pointers)
    auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
        d_in.GetDeviceBuffer(),
        d_out.GetDeviceBuffer(),
        d_out_index.GetDeviceBuffer(),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};
    // Stream config: no timing overhead for fastest execution
    ck_tile::stream_config stream_config{nullptr, false, 0, 0, 1, false, false, 1};
    // Launch generated kernel; an "Arguments not supported" exception means
    // the problem shape is outside this kernel variant's capability and the
    // case is skipped rather than failed.
    try
    {
        SelectedKernel::launch(host_args, stream_config);
    }
    catch(const std::exception& e)
    {
        std::string error_msg(e.what());
        if(error_msg.find("Arguments not supported") != std::string::npos)
        {
            GTEST_SKIP() << "Configuration not supported: " << e.what();
        }
        else
        {
            FAIL() << "Kernel launch failed: " << e.what();
        }
    }
    // Copy results back
    d_out.FromDevice(h_out.data());
    d_out_index.FromDevice(h_out_index.data());
    // Compute reference on host (same shapes/strides, host pointers)
    auto kernel_args_ref = ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
        h_in.data(),
        h_out_ref.data(),
        h_out_ref_index.data(),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};
    ck_tile::reference_pool2d<InDataType,
                              ComputeDataType,
                              OutDataType,
                              IndexDataType,
                              ReduceOpType,
                              decltype(input_shape),
                              decltype(window_lengths),
                              SelectedKernel::kOutputIndex>(
        h_in, h_out_ref, h_out_ref_index, kernel_args_ref, ReduceOpType{});
    // Verify value results
    // NOTE(review): fixed 1e-5 tolerances assume fp32-class accuracy; verify
    // these are loose enough for fp16/fp8 kernel builds.
    bool pass_value = ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
    EXPECT_TRUE(pass_value) << "Pooling value verification failed for " << KERNEL_NAME;
    // Verify index results if output_index is enabled (indices must match exactly)
    if constexpr(SelectedKernel::kOutputIndex)
    {
        bool pass_index =
            ck_tile::check_err(h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
        EXPECT_TRUE(pass_index) << "Pooling index verification failed for " << KERNEL_NAME;
    }
}
/// @brief Sanity-checks the generated kernel name and logs the 2D problem geometry.
TEST_P(PoolingTileEngineTest2D, KernelInfo)
{
    constexpr std::string_view kernel_name{KERNEL_NAME};
    EXPECT_FALSE(kernel_name.empty()) << "Kernel name should not be empty";
    std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
    std::cout << "Problem size: N=" << N_ << " H=" << H_ << " W=" << W_ << " C=" << C_
              << " Window=" << Y_ << "x" << X_ << " Output=" << Ho_ << "x" << Wo_ << std::endl;
}
// Instantiate test suite with config-specific test parameters
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params_2d.hpp file
// Instantiate the 2D suite with the config-specific parameter list
// (CONFIG_TEST_PARAMS comes from the auto-generated test_params_2d.hpp).
// The name generator encodes the input/window geometry into the test name.
INSTANTIATE_TEST_SUITE_P(PoolingVerification,
                         PoolingTileEngineTest2D,
                         ::testing::ValuesIn(CONFIG_TEST_PARAMS),
                         [](const ::testing::TestParamInfo<PoolTestParams2D>& info) {
                             const PoolTestParams2D& p = info.param;
                             std::string name = "N" + std::to_string(p.N);
                             name += "_H" + std::to_string(p.H);
                             name += "_W" + std::to_string(p.W);
                             name += "_C" + std::to_string(p.C);
                             name += "_Y" + std::to_string(p.Y);
                             name += "_X" + std::to_string(p.X);
                             return name;
                         });
#elif POOLING_DIM_VALUE == 3
// ============================================================================
// 3D Pooling Tests
// ============================================================================
/// @brief Parameterized fixture for 3D pooling kernel verification.
///
/// Unpacks the PoolTestParams3D under test into named members and derives the
/// output extents (Do_, Ho_, Wo_) from the standard pooling formula with
/// dilation and asymmetric padding.
class PoolingTileEngineTest3D : public ::testing::TestWithParam<PoolTestParams3D>
{
    protected:
    void SetUp() override
    {
        const PoolTestParams3D& p = GetParam();

        N_ = p.N;
        D_ = p.D;
        H_ = p.H;
        W_ = p.W;
        C_ = p.C;
        Z_ = p.Z;
        Y_ = p.Y;
        X_ = p.X;
        stride_d_   = p.stride_d;
        stride_h_   = p.stride_h;
        stride_w_   = p.stride_w;
        dilation_d_ = p.dilation_d;
        dilation_h_ = p.dilation_h;
        dilation_w_ = p.dilation_w;
        pad_d_left_  = p.pad_d_left;
        pad_d_right_ = p.pad_d_right;
        pad_h_left_  = p.pad_h_left;
        pad_h_right_ = p.pad_h_right;
        pad_w_left_  = p.pad_w_left;
        pad_w_right_ = p.pad_w_right;

        // Output extent: (len + pads - effective_window) / stride + 1, where the
        // effective (dilated) window spans (win - 1) * dilation + 1 elements.
        const auto pooled_extent =
            [](int len, int win, int stride, int dilation, int pad_lo, int pad_hi) -> int {
            const int eff_win = (win - 1) * dilation + 1;
            return (len + pad_lo + pad_hi - eff_win) / stride + 1;
        };
        Do_ = pooled_extent(D_, Z_, stride_d_, dilation_d_, pad_d_left_, pad_d_right_);
        Ho_ = pooled_extent(H_, Y_, stride_h_, dilation_h_, pad_h_left_, pad_h_right_);
        Wo_ = pooled_extent(W_, X_, stride_w_, dilation_w_, pad_w_left_, pad_w_right_);
    }

    int N_, D_, H_, W_, C_;                    // Input dimensions (NDHWC)
    int Z_, Y_, X_;                            // Window size
    int stride_d_, stride_h_, stride_w_;       // Strides
    int dilation_d_, dilation_h_, dilation_w_; // Dilations
    int pad_d_left_, pad_d_right_;             // Depth padding
    int pad_h_left_, pad_h_right_;             // Height padding
    int pad_w_left_, pad_w_right_;             // Width padding
    int Do_, Ho_, Wo_;                         // Derived output dimensions
};
/// @brief End-to-end verification of one generated 3D pooling kernel.
///
/// Runs SelectedKernel (provided by the -include'd generated header) on
/// device-resident NDHWC tensors, computes a host-side reference with
/// ck_tile::reference_pool3d, and compares values — plus argmax indices when
/// the kernel was built with kOutputIndex.
TEST_P(PoolingTileEngineTest3D, BasicFunctionality)
{
    // Create host tensors (NDHWC layout)
    ck_tile::HostTensor<InDataType> h_in({N_, D_, H_, W_, C_});
    ck_tile::HostTensor<OutDataType> h_out({N_, Do_, Ho_, Wo_, C_});
    ck_tile::HostTensor<OutDataType> h_out_ref({N_, Do_, Ho_, Wo_, C_});
    ck_tile::HostTensor<IndexDataType> h_out_index({N_, Do_, Ho_, Wo_, C_});
    ck_tile::HostTensor<IndexDataType> h_out_ref_index({N_, Do_, Ho_, Wo_, C_});
    // Initialize input with random data in [-5, 5]
    ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
    h_out.SetZero();
    h_out_ref.SetZero();
    // Device memory
    ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
    d_in.ToDevice(h_in.data());
    d_out.SetZero();
    d_out_index.SetZero();
    // Build shapes and strides (NDHWC layout, innermost dim C is contiguous)
    auto input_shape = ck_tile::make_tuple(N_, D_, H_, W_, C_);
    auto output_shape = ck_tile::make_tuple(N_, Do_, Ho_, Wo_, C_);
    auto input_strides = ck_tile::make_tuple(D_ * H_ * W_ * C_, H_ * W_ * C_, W_ * C_, C_, 1);
    auto output_strides =
        ck_tile::make_tuple(Do_ * Ho_ * Wo_ * C_, Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
    auto window_lengths = ck_tile::make_tuple(Z_, Y_, X_);
    auto window_strides = ck_tile::make_tuple(stride_d_, stride_h_, stride_w_);
    auto window_dilations = ck_tile::make_tuple(dilation_d_, dilation_h_, dilation_w_);
    auto input_left_pads = ck_tile::make_tuple(pad_d_left_, pad_h_left_, pad_w_left_);
    auto input_right_pads = ck_tile::make_tuple(pad_d_right_, pad_h_right_, pad_w_right_);
    // Build host args for the generated kernel (device pointers)
    auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
        d_in.GetDeviceBuffer(),
        d_out.GetDeviceBuffer(),
        d_out_index.GetDeviceBuffer(),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};
    // Stream config: no timing overhead for fastest execution
    ck_tile::stream_config stream_config{nullptr, false, 0, 0, 1, false, false, 1};
    // Launch generated kernel; an "Arguments not supported" exception means
    // the problem shape is outside this kernel variant's capability and the
    // case is skipped rather than failed.
    try
    {
        SelectedKernel::launch(host_args, stream_config);
    }
    catch(const std::exception& e)
    {
        std::string error_msg(e.what());
        if(error_msg.find("Arguments not supported") != std::string::npos)
        {
            GTEST_SKIP() << "Configuration not supported: " << e.what();
        }
        else
        {
            FAIL() << "Kernel launch failed: " << e.what();
        }
    }
    // Copy results back
    d_out.FromDevice(h_out.data());
    d_out_index.FromDevice(h_out_index.data());
    // Compute reference on host (same shapes/strides, host pointers)
    auto kernel_args_ref = ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
        h_in.data(),
        h_out_ref.data(),
        h_out_ref_index.data(),
        input_shape,
        output_shape,
        input_strides,
        output_strides,
        window_lengths,
        window_strides,
        window_dilations,
        input_left_pads,
        input_right_pads};
    ck_tile::reference_pool3d<InDataType,
                              ComputeDataType,
                              OutDataType,
                              IndexDataType,
                              ReduceOpType,
                              decltype(input_shape),
                              decltype(window_lengths),
                              SelectedKernel::kOutputIndex>(
        h_in, h_out_ref, h_out_ref_index, kernel_args_ref, ReduceOpType{});
    // Verify value results
    // NOTE(review): fixed 1e-5 tolerances assume fp32-class accuracy; verify
    // these are loose enough for fp16/fp8 kernel builds.
    bool pass_value = ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
    EXPECT_TRUE(pass_value) << "Pooling 3D value verification failed for " << KERNEL_NAME;
    // Verify index results if output_index is enabled (indices must match exactly)
    if constexpr(SelectedKernel::kOutputIndex)
    {
        bool pass_index =
            ck_tile::check_err(h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
        EXPECT_TRUE(pass_index) << "Pooling 3D index verification failed for " << KERNEL_NAME;
    }
}
/// @brief Sanity-checks the generated kernel name and logs the 3D problem geometry.
TEST_P(PoolingTileEngineTest3D, KernelInfo)
{
    EXPECT_TRUE(std::string_view(KERNEL_NAME).size() > 0) << "Kernel name should not be empty";
    std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
    std::cout << "Problem size: N=" << N_ << " D=" << D_ << " H=" << H_ << " W=" << W_
              << " C=" << C_ << " Window=" << Z_ << "x" << Y_ << "x" << X_ << " Output=" << Do_
              << "x" << Ho_ << "x" << Wo_ << std::endl;
}
// Instantiate test suite with config-specific test parameters
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params_3d.hpp file
// Instantiate the 3D suite with the config-specific parameter list
// (CONFIG_TEST_PARAMS comes from the auto-generated test_params_3d.hpp).
// The name generator encodes the input/window geometry into the test name.
INSTANTIATE_TEST_SUITE_P(PoolingVerification,
                         PoolingTileEngineTest3D,
                         ::testing::ValuesIn(CONFIG_TEST_PARAMS),
                         [](const ::testing::TestParamInfo<PoolTestParams3D>& param_info) {
                             return "N" + std::to_string(param_info.param.N) + "_D" +
                                    std::to_string(param_info.param.D) + "_H" +
                                    std::to_string(param_info.param.H) + "_W" +
                                    std::to_string(param_info.param.W) + "_C" +
                                    std::to_string(param_info.param.C) + "_Z" +
                                    std::to_string(param_info.param.Z) + "_Y" +
                                    std::to_string(param_info.param.Y) + "_X" +
                                    std::to_string(param_info.param.X);
                         });
#else
#error "POOLING_DIM_VALUE must be 2 or 3"
#endif

View File

@@ -7,5 +7,6 @@ include_directories(BEFORE
add_subdirectory(ops/gemm)
add_subdirectory(ops/gemm_streamk)
add_subdirectory(ops/pooling)
add_subdirectory(ops/reduce)

View File

@@ -231,7 +231,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950
set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -242,7 +242,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -219,7 +219,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, and gfx950
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -230,7 +230,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -237,7 +237,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -25,6 +25,16 @@ ELEMENT_SIZE_MAP = {
"fp64": 8,
}
def get_warp_size_for_gpu(gpu_target: str) -> int:
"""Get the warp size for a given GPU target.
CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront).
RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront).
"""
if gpu_target.startswith("gfx9"):
return 64 # CDNA - WAVE64
return 32 # RDNA and others - WAVE32
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
@@ -586,10 +596,11 @@ def validate_whole_wg_cover_configuration(
layout,
a_datatype,
b_datatype,
gpu_target: str = "gfx90a",
) -> Tuple[bool, str]:
# Validate whole workgroup cover configuration
warp_size = 64
warp_size = get_warp_size_for_gpu(gpu_target)
NumWarps = warp_m * warp_n * warp_k
BlockSize = NumWarps * warp_size
@@ -704,6 +715,73 @@ def wg_cover_core_validation(
return True, ""
def validate_cshuffle_epilogue_distribution(
tile_m: int,
tile_n: int,
warp_m: int,
warp_n: int,
warp_k: int,
warp_tile_m: int,
warp_tile_n: int,
warp_size: int,
c_datatype: str,
) -> Tuple[bool, str]:
"""
Validate that the CShuffleEpilogue tile distribution pattern is valid.
This mirrors the static_assert in static_encoding_pattern.hpp:
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
The CShuffleEpilogue creates a tile_distribution_encoding_pattern_2d<BlockSize, YPerTile, XPerTile, VecSize, thread_raked>
where:
- BlockSize = warp_m * warp_n * warp_k * warp_size
- YPerTile = MPerIterationShuffle (derived from tile_m / (warp_m * warp_tile_m / some_factor))
- XPerTile = NPerIterationShuffle (derived from tile_n)
- VecSize = vector size based on element size (typically 8 for fp16)
The key constraint is that X0 must evenly divide warp_size, where:
- X0 = min(warp_size, XPerTile / X1)
- X1 = min(VecSize, LargestVec)
- LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
"""
NumWarps = warp_m * warp_n * warp_k
BlockSize = NumWarps * warp_size
elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2)
VecSize = 16 // elem_size
XPerTile = tile_n
YPerTile = tile_m // warp_m
if XPerTile <= 0 or YPerTile <= 0:
return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}"
num_warps = BlockSize // warp_size
if num_warps * warp_size == 0:
return False, "Invalid BlockSize or warp_size"
LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size)
if LargestVec <= 0:
LargestVec = 1
X1 = min(VecSize, LargestVec) if LargestVec > 0 else VecSize
if X1 <= 0:
X1 = 1
X0 = min(warp_size, XPerTile // X1) if X1 > 0 else warp_size
Y1 = warp_size // X0 if X0 > 0 else 0
if X0 * Y1 != warp_size:
return (
False,
f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). "
f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}"
)
return True, ""
def get_global_vector_load_size(
BlockSize: int,
KPerBlock: int,
@@ -766,6 +844,8 @@ def validate_gemm(
trait_name: str = None,
) -> bool:
# GEMM Validation
warp_size = get_warp_size_for_gpu(gpu_target)
# Validate whole workgroup cover configuration
whole_workgroup_cover_valid, whole_workgroup_cover_error = (
validate_whole_wg_cover_configuration(
@@ -778,6 +858,7 @@ def validate_gemm(
layout,
a_datatype,
b_datatype,
gpu_target,
)
)
if not whole_workgroup_cover_valid:
@@ -786,6 +867,23 @@ def validate_gemm(
)
return False, whole_workgroup_cover_error
# Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue)
# This validation ensures the tile distribution pattern is valid for the output tile
cshuffle_valid, cshuffle_error = validate_cshuffle_epilogue_distribution(
tile_m,
tile_n,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_size,
c_datatype,
)
if not cshuffle_valid:
logging.debug(f"CShuffleEpilogue validation failed: {cshuffle_error}")
return False, cshuffle_error
return True, ""
@@ -808,6 +906,8 @@ def validate_gemm_preshuffle(
trait_name: str = None,
) -> bool:
# Preshuffle Validations
warp_size = get_warp_size_for_gpu(gpu_target)
# Validate vector load alignment
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
vector_valid, vector_error = validate_vector_load_alignment(
@@ -815,7 +915,7 @@ def validate_gemm_preshuffle(
warp_tile_k,
a_datatype,
m_iter_per_warp,
wave_size=64,
wave_size=warp_size,
vector_load_size=16,
)
if not vector_valid:
@@ -831,7 +931,7 @@ def validate_gemm_preshuffle(
warp_k,
a_datatype,
vector_load_size=16,
warp_size=64,
warp_size=warp_size,
)
if not m0_m1_m2_valid:
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")

View File

@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx942, gfx950
set(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx942;gfx950")
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -237,7 +237,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual Grouped GEMM targets for GPU targets: ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -216,7 +216,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported
set(DESIRED_TARGETS "gfx90a;gfx942;gfx12-generic") # TODO: Add gfx950 when supported
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -227,7 +227,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -0,0 +1,212 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ============================================================================
# Pooling Tile Engine Build Configuration
#
# Generates individual benchmark executables for pooling kernels
# ============================================================================
set(POOLING_DATATYPE "fp8;fp16;fp32" CACHE STRING "List of datatypes for Pooling (semicolon-separated)")
set(POOLING_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_POOLING "Enable ccache for pooling ops compilation" OFF)
# Store the directory path for use in functions
set(POOLING_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# ============================================================================
# create_individual_pool_target
#
# Creates a single benchmark executable for a specific pooling kernel config.
# ============================================================================
# ============================================================================
# create_individual_pool_target
#
# Creates a single benchmark executable for a specific pooling kernel config.
#
# Arguments:
#   datatype    - element type of the kernel (e.g. fp16)
#   kernel_name - kernel identifier from pool_kernel_list.txt (header basename)
#   trait       - trait combination string passed to the instance builder
#   tile_config - tile configuration string passed to the instance builder
#   config_json - absolute path to the JSON config file
# ============================================================================
function(create_individual_pool_target datatype kernel_name trait tile_config config_json)
    if(NOT POOLING_GPU_TARGETS)
        message(WARNING "Skipping individual pooling target: No supported GPU targets")
        return()
    endif()
    set(target_name "benchmark_pooling_${datatype}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")
    # HIP clang offload uses temporary files derived from the input source basename.
    # When many targets compile the same source filename in parallel, temporary
    # files can collide and corrupt each other. Use a unique copied source per target.
    set(target_source "${CMAKE_CURRENT_BINARY_DIR}/${target_name}_pooling_benchmark_single.cpp")
    # Generated header path - use kernel_name from pool_kernel_list.txt to match
    # the filename generated by pooling_instance_builder.py
    set(instance_header "${working_path}/pooling_single_${kernel_name}.hpp")
    # Generate the header at build time. VERBATIM guarantees portable argument
    # escaping across generators/platforms.
    add_custom_command(
        OUTPUT ${instance_header}
        COMMAND ${Python3_EXECUTABLE} ${POOLING_SOURCE_DIR}/pooling_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --config_json ${config_json}
                --gen_single
                --kernel_name "${kernel_name}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
        DEPENDS ${POOLING_SOURCE_DIR}/pooling_instance_builder.py ${config_json}
        COMMENT "Generating ${instance_header}"
        VERBATIM
    )
    configure_file(${POOLING_SOURCE_DIR}/pooling_benchmark_single.cpp ${target_source} COPYONLY)
    # Create the executable (listing the header makes the custom command a dependency)
    add_executable(${target_name}
        ${target_source}
        ${instance_header}
    )
    # Set GPU architectures
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${POOLING_GPU_TARGETS})
    # Expose the generated header path to the source
    target_compile_definitions(${target_name} PRIVATE
        POOLING_SINGLE_INSTANCE_HPP="${instance_header}"
    )
    # Include directories
    target_include_directories(${target_name} PRIVATE
        ${POOLING_SOURCE_DIR}
        ${working_path}
    )
    # Compile options; the generated kernel header is force-included so the
    # benchmark source sees SelectedKernel and its type aliases.
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )
    # Add FP8 format definition if needed (a definition, not a raw -D option)
    if(CK_USE_OCP_FP8)
        target_compile_definitions(${target_name} PRIVATE CK_TILE_USE_OCP_FP8)
    endif()
    # Add to collection targets
    add_dependencies(benchmark_pooling_all ${target_name})
    add_dependencies(benchmark_pooling_${datatype} ${target_name})
    message(STATUS "  Created pooling benchmark target: ${target_name}")
endfunction()
# ============================================================================
# build_individual_pool_targets
#
# Builds all benchmark targets for a specific datatype.
# ============================================================================
# ============================================================================
# build_individual_pool_targets
#
# Builds all benchmark targets for a specific datatype: picks the JSON config,
# asks pooling_instance_builder.py to enumerate kernels at configure time,
# then creates one benchmark target per kernel line.
#
# Uses POOLING_SOURCE_DIR (captured at file scope) rather than
# CMAKE_CURRENT_LIST_DIR so path resolution is stable regardless of where the
# function is invoked from, consistent with create_individual_pool_target.
# ============================================================================
function(build_individual_pool_targets datatype)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")
    # Choose config file: environment variable > cache variable > default
    if(DEFINED ENV{POOLING_CONFIG_FILE} AND NOT "$ENV{POOLING_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{POOLING_CONFIG_FILE}")
        set(json_blob "${POOLING_SOURCE_DIR}/configs/${config_filename}")
        message(STATUS "  Using config from environment variable: ${config_filename}")
    elseif(NOT "${POOLING_CONFIG_FILE}" STREQUAL "")
        set(json_blob "${POOLING_SOURCE_DIR}/configs/${POOLING_CONFIG_FILE}")
        message(STATUS "  Using custom config: ${POOLING_CONFIG_FILE}")
    else()
        set(json_blob "${POOLING_SOURCE_DIR}/configs/default_config.json")
        message(STATUS "  Using default config for pooling")
    endif()
    if(NOT EXISTS "${json_blob}")
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()
    file(MAKE_DIRECTORY ${working_path})
    # Step 1: Enumerate kernel configurations at configure time
    message(STATUS "  Listing pooling kernel configurations for ${datatype}...")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${POOLING_SOURCE_DIR}/pooling_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --config_json ${json_blob}
                --list_kernels
        WORKING_DIRECTORY ${POOLING_SOURCE_DIR}
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )
    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list pooling kernels for ${datatype}: ${list_error}")
    endif()
    # Read kernel count (written by the builder script)
    if(EXISTS "${working_path}/pool_kernel_count.txt")
        file(READ ${working_path}/pool_kernel_count.txt kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(STATUS "  Found ${kernel_count} pooling kernel configurations")
    else()
        message(FATAL_ERROR "Pooling kernel count file not found")
    endif()
    # Step 2: Create one target per "kernel_name|tile_config|trait_combo" line
    if(EXISTS "${working_path}/pool_kernel_list.txt")
        file(STRINGS ${working_path}/pool_kernel_list.txt kernel_lines)
        foreach(line IN LISTS kernel_lines)
            string(REPLACE "|" ";" parts "${line}")
            list(LENGTH parts parts_len)
            if(parts_len EQUAL 3)
                list(GET parts 0 kernel_name)
                list(GET parts 1 tile_config)
                list(GET parts 2 trait_combo)
                create_individual_pool_target("${datatype}" "${kernel_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
            endif()
        endforeach()
    else()
        message(FATAL_ERROR "Pooling kernel list file not found")
    endif()
endfunction()
# ============================================================================
# MAIN EXECUTION
# ============================================================================
# ============================================================================
# MAIN EXECUTION
#
# Filters SUPPORTED_GPU_TARGETS against the architectures this op supports,
# creates the aggregate custom targets, and generates per-datatype benchmark
# targets from the selected JSON config.
# ============================================================================
message(STATUS "=== Starting Tile Engine Pooling Configuration ===")
message(STATUS "POOLING_DATATYPE: ${POOLING_DATATYPE}")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to the architectures pooling kernels are built for
set(POOLING_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
    if(target IN_LIST DESIRED_TARGETS)
        list(APPEND POOLING_GPU_TARGETS ${target})
        message(STATUS "  Adding GPU target for pooling: ${target}")
    endif()
endforeach()
if(NOT POOLING_GPU_TARGETS)
    message(WARNING "Skipping Tile Engine Pooling build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
    message(STATUS "Building pooling targets for GPU targets: ${POOLING_GPU_TARGETS}")
    # Enable ccache if requested
    # NOTE(review): setting CMAKE_CXX_COMPILER_LAUNCHER here affects every
    # target configured at/below this directory, not only pooling — confirm
    # this directory-wide scope is intended.
    if(ENABLE_CCACHE_POOLING)
        find_program(CCACHE_PROGRAM ccache)
        if(CCACHE_PROGRAM)
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(STATUS "Using ccache for pooling compilation")
        endif()
    endif()
    # Create aggregate targets (individual benchmarks attach via add_dependencies)
    add_custom_target(benchmark_pooling_all)
    foreach(dt IN LISTS POOLING_DATATYPE)
        add_custom_target(benchmark_pooling_${dt})
    endforeach()
    # Build targets for each datatype
    foreach(dt IN LISTS POOLING_DATATYPE)
        build_individual_pool_targets(${dt})
    endforeach()
endif()

View File

@@ -0,0 +1,21 @@
{
"problem": {
"description": "Default pooling configuration for tile_engine benchmarks"
},
"tile_config": {
"block_m": {"values": [64,128,256]},
"block_n": {"values": [1,2]},
"warp_m": {"values": [1]},
"warp_n": {"values": [1]},
"warp_tile_m": {"values": [128]},
"warp_tile_n": {"values": [1]},
"thread_tile_m": {"values": [1,2,4]},
"thread_tile_n": {"values": [1]}
},
"trait_config": {
"reduce_op": {"values": ["max", "min", "avg"]},
"output_index": {"values": [true, false]},
"propagate_nan": {"values": [true, false]},
"pooling_dim": {"values": ["2d", "3d"]}
}
}

View File

@@ -0,0 +1,132 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <vector>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
namespace ck_tile {
/// @brief Performance metrics for benchmarking
/// @brief Performance metric selector for pooling benchmark reporting.
enum class PoolMetric
{
    LATENCY,  ///< report kernel execution time
    BANDWIDTH ///< report effective memory bandwidth
};
/// @brief Pooling problem specification for 2D pooling
struct PoolProblem2D
{
index_t N, H, W, C; // Input dimensions (NHWC)
index_t Y, X; // Window dimensions
index_t stride_h, stride_w; // Window strides
index_t dilation_h, dilation_w; // Window dilations
index_t pad_h_left, pad_h_right; // Height padding
index_t pad_w_left, pad_w_right; // Width padding
std::string datatype; // Data type name
std::string reduce_op; // "max", "min", or "avg"
index_t Ho() const
{
index_t Ys = (Y - 1) * dilation_h + 1;
return (H + pad_h_left + pad_h_right - Ys) / stride_h + 1;
}
index_t Wo() const
{
index_t Xs = (X - 1) * dilation_w + 1;
return (W + pad_w_left + pad_w_right - Xs) / stride_w + 1;
}
index_t input_elements() const { return N * H * W * C; }
index_t output_elements() const { return N * Ho() * Wo() * C; }
std::string to_string() const
{
std::ostringstream oss;
oss << "N" << N << "_H" << H << "_W" << W << "_C" << C << "_Y" << Y << "_X" << X << "_Sh"
<< stride_h << "_Sw" << stride_w << "_Dh" << dilation_h << "_Dw" << dilation_w;
if(pad_h_left > 0 || pad_w_left > 0)
oss << "_Ph" << pad_h_left << "_Pw" << pad_w_left;
return oss.str();
}
};
/// @brief Pooling problem specification for 3D pooling
/// @brief Pooling problem specification for 3D pooling (NDHWC layout).
///
/// Holds the input geometry, window parameters, and op selection; derives the
/// output extents from the standard pooling formula with dilation and
/// asymmetric padding. Member order is kept stable for aggregate init.
struct PoolProblem3D
{
    index_t N, D, H, W, C;                     // Input dimensions (NDHWC)
    index_t Z, Y, X;                           // Window dimensions
    index_t stride_d, stride_h, stride_w;      // Window strides
    index_t dilation_d, dilation_h, dilation_w; // Window dilations
    index_t pad_d_left, pad_d_right;           // Depth padding
    index_t pad_h_left, pad_h_right;           // Height padding
    index_t pad_w_left, pad_w_right;           // Width padding
    std::string datatype;                      // Data type name
    std::string reduce_op;                     // "max", "min", or "avg"

    /// Output depth: derived from D, padding, dilated window span and stride.
    index_t Do() const
    {
        const index_t eff_win_d = (Z - 1) * dilation_d + 1;
        return (D + pad_d_left + pad_d_right - eff_win_d) / stride_d + 1;
    }
    /// Output height: derived from H, padding, dilated window span and stride.
    index_t Ho() const
    {
        const index_t eff_win_h = (Y - 1) * dilation_h + 1;
        return (H + pad_h_left + pad_h_right - eff_win_h) / stride_h + 1;
    }
    /// Output width: derived from W, padding, dilated window span and stride.
    index_t Wo() const
    {
        const index_t eff_win_w = (X - 1) * dilation_w + 1;
        return (W + pad_w_left + pad_w_right - eff_win_w) / stride_w + 1;
    }
    /// Total element count of the input tensor.
    index_t input_elements() const { return N * D * H * W * C; }
    /// Total element count of the output tensor.
    index_t output_elements() const { return N * Do() * Ho() * Wo() * C; }
    /// Compact problem identifier, e.g. "N1_D16_H32_W32_C64_Z3_Y3_X3".
    std::string to_string() const
    {
        std::string s = "N" + std::to_string(N);
        s += "_D" + std::to_string(D);
        s += "_H" + std::to_string(H);
        s += "_W" + std::to_string(W);
        s += "_C" + std::to_string(C);
        s += "_Z" + std::to_string(Z);
        s += "_Y" + std::to_string(Y);
        s += "_X" + std::to_string(X);
        return s;
    }
};
/// @brief Performance result for a single pooling kernel run.
struct PoolPerformanceResult
{
    float latency_ms;     // measured kernel latency in milliseconds
    float bandwidth_gb_s; // achieved effective memory bandwidth in GB/s
    /// @brief Human-readable summary, e.g. "latency=1.5ms, bandwidth=2GB/s".
    std::string to_string() const
    {
        std::ostringstream text;
        text << "latency=" << latency_ms << "ms, bandwidth=" << bandwidth_gb_s << "GB/s";
        return text.str();
    }
};
/// @brief Benchmark settings
struct PoolBenchmarkSetting
{
    int warmup = 5;      // warm-up iterations (not timed)
    int repeat = 20;     // timed iterations
    bool verify = true;  // compare GPU result against the host reference
    int init_method = 0; // 0: uniform random, 1: integer sequence, 2: constant, 3: special
};
} // namespace ck_tile

View File

@@ -0,0 +1,390 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file pooling_benchmark_single.cpp
* @brief Single-kernel benchmark for pooling operations (2D and 3D).
*
* This benchmark includes the generated kernel header via -include flag
* and runs the pooling kernel with specified problem sizes.
*
* The generated header provides:
* - SelectedKernel (struct with ::launch())
* - KERNEL_NAME (constexpr const char*)
* - POOLING_DIM (constexpr int, 2 or 3)
* - InDataType, OutDataType, ComputeDataType, IndexDataType, ReduceOpType
* - TensorShape, WindowShape
*/
#include <iostream>
#include <string>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "pooling_common.hpp"
#include "pooling_benchmark.hpp"
// The kernel header is included via compile command line with -include flag.
// --------------------------------------------------------------------------
// Benchmark implementation — templated on pooling dimension so that only
// the matching branch is instantiated (2D or 3D).
// --------------------------------------------------------------------------
// Thin adapter around SelectedKernel::launch. SelectedKernel is supplied by
// the generated kernel header passed on the compile line via -include; the
// returned value is the latency reported by the launcher.
template <typename HostArgs>
static float launch_selected_kernel(HostArgs& args, const ck_tile::stream_config& stream)
{
    return SelectedKernel::launch(args, stream);
}
/// @brief Run the pooling benchmark for the kernel selected at compile time.
///
/// The generated header (via -include) provides SelectedKernel, KERNEL_NAME,
/// the data types and ReduceOpType used below. Only the branch matching
/// @p PoolDim is instantiated: 2 => NHWC, 3 => NDHWC.
///
/// @tparam PoolDim pooling dimensionality (2 or 3)
/// @return 0 on success, -1 on argument-parse failure or launch failure
template <int PoolDim>
static int benchmark_pooling(int argc, char* argv[])
{
    if constexpr(PoolDim == 2)
    {
        // ---- 2D argument parser ----
        ck_tile::ArgParser arg_parser;
        arg_parser.insert("n", "1", "Batch size (N)")
            .insert("h", "16", "Input height (H)")
            .insert("w", "16", "Input width (W)")
            .insert("c", "32", "Channels (C)")
            .insert("wy", "2", "Window height (Y)")
            .insert("wx", "2", "Window width (X)")
            .insert("sy", "2", "Window stride height")
            .insert("sx", "2", "Window stride width")
            .insert("dy", "1", "Window dilation height")
            .insert("dx", "1", "Window dilation width")
            .insert("phy", "0", "Padding height left")
            .insert("phyr", "0", "Padding height right")
            .insert("pwx", "0", "Padding width left")
            .insert("pwxr", "0", "Padding width right")
            .insert("verify", "1", "Verify results (0/1)")
            .insert("warmup", "5", "Warmup iterations")
            .insert("repeat", "20", "Repeat iterations")
            .insert("log", "1", "Log level");
        if(!arg_parser.parse(argc, argv))
            return -1;
        ck_tile::index_t N       = arg_parser.get_int("n");
        ck_tile::index_t H       = arg_parser.get_int("h");
        ck_tile::index_t W       = arg_parser.get_int("w");
        ck_tile::index_t C       = arg_parser.get_int("c");
        ck_tile::index_t Y       = arg_parser.get_int("wy");
        ck_tile::index_t X       = arg_parser.get_int("wx");
        ck_tile::index_t Sy      = arg_parser.get_int("sy");
        ck_tile::index_t Sx      = arg_parser.get_int("sx");
        ck_tile::index_t Dy      = arg_parser.get_int("dy");
        ck_tile::index_t Dx      = arg_parser.get_int("dx");
        ck_tile::index_t LeftPy  = arg_parser.get_int("phy");
        ck_tile::index_t RightPy = arg_parser.get_int("phyr");
        ck_tile::index_t LeftPx  = arg_parser.get_int("pwx");
        ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
        bool verify              = arg_parser.get_int("verify") != 0;
        int warmup               = arg_parser.get_int("warmup");
        int repeat               = arg_parser.get_int("repeat");
        int log_level            = arg_parser.get_int("log");
        // Output extents from the dilated-window pooling formula.
        ck_tile::index_t Ys = (Y - 1) * Dy + 1;
        ck_tile::index_t Xs = (X - 1) * Dx + 1;
        ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
        ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
        std::cout << "Pooling 2D benchmark: " << KERNEL_NAME << std::endl;
        std::cout << "  Input: NHWC = " << N << "x" << H << "x" << W << "x" << C << std::endl;
        std::cout << "  Output: NHWC = " << N << "x" << Ho << "x" << Wo << "x" << C << std::endl;
        std::cout << "  Window: " << Y << "x" << X << ", stride: " << Sy << "x" << Sx
                  << ", dilation: " << Dy << "x" << Dx << std::endl;
        // Host-side buffers; *_ref hold the reference result for verification.
        ck_tile::HostTensor<InDataType> h_in({N, H, W, C});
        ck_tile::HostTensor<OutDataType> h_out({N, Ho, Wo, C});
        ck_tile::HostTensor<OutDataType> h_out_ref({N, Ho, Wo, C});
        ck_tile::HostTensor<IndexDataType> h_out_index({N, Ho, Wo, C});
        ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Ho, Wo, C});
        ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
        ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
        d_in.ToDevice(h_in.data());
        d_out.SetZero();
        d_out_index.SetZero();
        // Packed NHWC strides (innermost C contiguous).
        auto input_shape    = ck_tile::make_tuple(N, H, W, C);
        auto output_shape   = ck_tile::make_tuple(N, Ho, Wo, C);
        auto input_strides  = ck_tile::make_tuple(H * W * C, W * C, C, ck_tile::index_t{1});
        auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
        auto window_lengths   = ck_tile::make_tuple(Y, X);
        auto window_strides   = ck_tile::make_tuple(Sy, Sx);
        auto window_dilations = ck_tile::make_tuple(Dy, Dx);
        auto input_left_pads  = ck_tile::make_tuple(LeftPy, LeftPx);
        auto input_right_pads = ck_tile::make_tuple(RightPy, RightPx);
        auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
            d_in.GetDeviceBuffer(),
            d_out.GetDeviceBuffer(),
            d_out_index.GetDeviceBuffer(),
            input_shape,
            output_shape,
            input_strides,
            output_strides,
            window_lengths,
            window_strides,
            window_dilations,
            input_left_pads,
            input_right_pads};
        // NOTE(review): assumes stream_config{stream, time_kernel, log_level,
        // warmup, repeat} field order — confirm against ck_tile::stream_config.
        ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
        float latency = 0;
        try
        {
            latency = launch_selected_kernel(host_args, stream);
        }
        catch(const std::exception& e)
        {
            std::cerr << "Kernel launch failed: " << e.what() << std::endl;
            return -1;
        }
        // Bandwidth model: one full read of the input plus one full write of
        // the output (index traffic is not counted).
        size_t bytes_read    = static_cast<size_t>(N) * H * W * C * sizeof(InDataType);
        size_t bytes_written = static_cast<size_t>(N) * Ho * Wo * C * sizeof(OutDataType);
        float bandwidth      = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
        std::cout << "  Latency: " << latency << " ms" << std::endl;
        std::cout << "  Bandwidth: " << bandwidth << " GB/s" << std::endl;
        if(verify)
        {
            // Recompute on the host with the reference implementation and compare.
            d_out.FromDevice(h_out.data());
            d_out_index.FromDevice(h_out_index.data());
            auto kernel_args =
                ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
                    h_in.data(),
                    h_out_ref.data(),
                    h_out_ref_index.data(),
                    input_shape,
                    output_shape,
                    input_strides,
                    output_strides,
                    window_lengths,
                    window_strides,
                    window_dilations,
                    input_left_pads,
                    input_right_pads};
            ck_tile::reference_pool2d<InDataType,
                                      ComputeDataType,
                                      OutDataType,
                                      IndexDataType,
                                      ReduceOpType,
                                      decltype(input_shape),
                                      decltype(window_lengths),
                                      SelectedKernel::kOutputIndex>(
                h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
            bool pass_value =
                ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
            std::cout << "  Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
            if(SelectedKernel::kOutputIndex)
            {
                // Indices must match exactly, hence zero tolerances.
                bool pass_index = ck_tile::check_err(
                    h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
                std::cout << "  Index verification: " << (pass_index ? "PASS" : "FAIL")
                          << std::endl;
            }
        }
        return 0;
    }
    else // PoolDim == 3
    {
        // ---- 3D argument parser ----
        ck_tile::ArgParser arg_parser;
        arg_parser.insert("n", "1", "Batch size (N)")
            .insert("d", "4", "Input depth (D)")
            .insert("h", "16", "Input height (H)")
            .insert("w", "16", "Input width (W)")
            .insert("c", "32", "Channels (C)")
            .insert("wz", "2", "Window depth (Z)")
            .insert("wy", "2", "Window height (Y)")
            .insert("wx", "2", "Window width (X)")
            .insert("sz", "2", "Window stride depth")
            .insert("sy", "2", "Window stride height")
            .insert("sx", "2", "Window stride width")
            .insert("dz", "1", "Window dilation depth")
            .insert("dy", "1", "Window dilation height")
            .insert("dx", "1", "Window dilation width")
            .insert("pdz", "0", "Padding depth left")
            .insert("pdzr", "0", "Padding depth right")
            .insert("phy", "0", "Padding height left")
            .insert("phyr", "0", "Padding height right")
            .insert("pwx", "0", "Padding width left")
            .insert("pwxr", "0", "Padding width right")
            .insert("verify", "1", "Verify results (0/1)")
            .insert("warmup", "5", "Warmup iterations")
            .insert("repeat", "20", "Repeat iterations")
            .insert("log", "1", "Log level");
        if(!arg_parser.parse(argc, argv))
            return -1;
        ck_tile::index_t N       = arg_parser.get_int("n");
        ck_tile::index_t D       = arg_parser.get_int("d");
        ck_tile::index_t H       = arg_parser.get_int("h");
        ck_tile::index_t W       = arg_parser.get_int("w");
        ck_tile::index_t C       = arg_parser.get_int("c");
        ck_tile::index_t Z       = arg_parser.get_int("wz");
        ck_tile::index_t Y       = arg_parser.get_int("wy");
        ck_tile::index_t X       = arg_parser.get_int("wx");
        ck_tile::index_t Sz      = arg_parser.get_int("sz");
        ck_tile::index_t Sy      = arg_parser.get_int("sy");
        ck_tile::index_t Sx      = arg_parser.get_int("sx");
        ck_tile::index_t Dz      = arg_parser.get_int("dz");
        ck_tile::index_t Dy      = arg_parser.get_int("dy");
        ck_tile::index_t Dx      = arg_parser.get_int("dx");
        ck_tile::index_t LeftPz  = arg_parser.get_int("pdz");
        ck_tile::index_t RightPz = arg_parser.get_int("pdzr");
        ck_tile::index_t LeftPy  = arg_parser.get_int("phy");
        ck_tile::index_t RightPy = arg_parser.get_int("phyr");
        ck_tile::index_t LeftPx  = arg_parser.get_int("pwx");
        ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
        bool verify              = arg_parser.get_int("verify") != 0;
        int warmup               = arg_parser.get_int("warmup");
        int repeat               = arg_parser.get_int("repeat");
        int log_level            = arg_parser.get_int("log");
        // Output extents from the dilated-window pooling formula.
        ck_tile::index_t Zs = (Z - 1) * Dz + 1;
        ck_tile::index_t Ys = (Y - 1) * Dy + 1;
        ck_tile::index_t Xs = (X - 1) * Dx + 1;
        ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
        ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
        ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
        std::cout << "Pooling 3D benchmark: " << KERNEL_NAME << std::endl;
        std::cout << "  Input: NDHWC = " << N << "x" << D << "x" << H << "x" << W << "x" << C
                  << std::endl;
        std::cout << "  Output: NDHWC = " << N << "x" << Do << "x" << Ho << "x" << Wo << "x" << C
                  << std::endl;
        std::cout << "  Window: " << Z << "x" << Y << "x" << X << ", stride: " << Sz << "x" << Sy
                  << "x" << Sx << ", dilation: " << Dz << "x" << Dy << "x" << Dx << std::endl;
        // Host-side buffers; *_ref hold the reference result for verification.
        ck_tile::HostTensor<InDataType> h_in({N, D, H, W, C});
        ck_tile::HostTensor<OutDataType> h_out({N, Do, Ho, Wo, C});
        ck_tile::HostTensor<OutDataType> h_out_ref({N, Do, Ho, Wo, C});
        ck_tile::HostTensor<IndexDataType> h_out_index({N, Do, Ho, Wo, C});
        ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Do, Ho, Wo, C});
        ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
        ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
        d_in.ToDevice(h_in.data());
        d_out.SetZero();
        d_out_index.SetZero();
        // Packed NDHWC strides (innermost C contiguous).
        auto input_shape  = ck_tile::make_tuple(N, D, H, W, C);
        auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
        auto input_strides =
            ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, ck_tile::index_t{1});
        auto output_strides =
            ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
        auto window_lengths   = ck_tile::make_tuple(Z, Y, X);
        auto window_strides   = ck_tile::make_tuple(Sz, Sy, Sx);
        auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
        auto input_left_pads  = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
        auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
        auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
            d_in.GetDeviceBuffer(),
            d_out.GetDeviceBuffer(),
            d_out_index.GetDeviceBuffer(),
            input_shape,
            output_shape,
            input_strides,
            output_strides,
            window_lengths,
            window_strides,
            window_dilations,
            input_left_pads,
            input_right_pads};
        // NOTE(review): assumes stream_config{stream, time_kernel, log_level,
        // warmup, repeat} field order — confirm against ck_tile::stream_config.
        ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
        float latency = 0;
        try
        {
            latency = launch_selected_kernel(host_args, stream);
        }
        catch(const std::exception& e)
        {
            std::cerr << "Kernel launch failed: " << e.what() << std::endl;
            return -1;
        }
        // Bandwidth model: one full read of the input plus one full write of
        // the output (index traffic is not counted).
        size_t bytes_read    = static_cast<size_t>(N) * D * H * W * C * sizeof(InDataType);
        size_t bytes_written = static_cast<size_t>(N) * Do * Ho * Wo * C * sizeof(OutDataType);
        float bandwidth      = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
        std::cout << "  Latency: " << latency << " ms" << std::endl;
        std::cout << "  Bandwidth: " << bandwidth << " GB/s" << std::endl;
        if(verify)
        {
            // Recompute on the host with the reference implementation and compare.
            d_out.FromDevice(h_out.data());
            d_out_index.FromDevice(h_out_index.data());
            auto kernel_args =
                ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
                    h_in.data(),
                    h_out_ref.data(),
                    h_out_ref_index.data(),
                    input_shape,
                    output_shape,
                    input_strides,
                    output_strides,
                    window_lengths,
                    window_strides,
                    window_dilations,
                    input_left_pads,
                    input_right_pads};
            ck_tile::reference_pool3d<InDataType,
                                      ComputeDataType,
                                      OutDataType,
                                      IndexDataType,
                                      ReduceOpType,
                                      decltype(input_shape),
                                      decltype(window_lengths),
                                      SelectedKernel::kOutputIndex>(
                h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
            bool pass_value =
                ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
            std::cout << "  Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
            if(SelectedKernel::kOutputIndex)
            {
                // Indices must match exactly, hence zero tolerances.
                bool pass_index = ck_tile::check_err(
                    h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
                std::cout << "  Index verification: " << (pass_index ? "PASS" : "FAIL")
                          << std::endl;
            }
        }
        return 0;
    }
}
// Entry point: dispatch to the 2D or 3D benchmark selected by the generated
// header's POOLING_DIM constant.
int main(int argc, char* argv[])
{
    return benchmark_pooling<POOLING_DIM>(argc, argv);
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <sstream>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/pooling.hpp"
namespace ck_tile {
/// @brief Kernel trait parameters for pooling tile_engine configurations.
struct PoolingKernelTraits
{
    std::string reduce_op; // "max", "min", or "avg"
    bool output_index;     // Whether to output indices (max pooling)
    bool propagate_nan;    // Whether to propagate NaN values
    bool cross_warp;       // Whether cross-warp reduction is used
    /// Encode as "<op>_<idx|noidx>_<nan|nonan>_<crosswarp|nocrosswarp>".
    std::string to_string() const
    {
        std::string encoded = reduce_op;
        encoded += '_';
        encoded += output_index ? "idx" : "noidx";
        encoded += '_';
        encoded += propagate_nan ? "nan" : "nonan";
        encoded += '_';
        encoded += cross_warp ? "crosswarp" : "nocrosswarp";
        return encoded;
    }
};
/// @brief Parse trait flags back out of a kernel name string.
inline PoolingKernelTraits extract_pooling_traits_from_name(const std::string& name)
{
    const auto contains = [&name](const char* token) {
        return name.find(token) != std::string::npos;
    };
    PoolingKernelTraits traits;
    if(contains("max"))
        traits.reduce_op = "max";
    else if(contains("min"))
        traits.reduce_op = "min";
    else
        traits.reduce_op = "avg";
    // The negated tokens ("noidx", "nonan", "nocrosswarp") contain the positive
    // token as a substring, so the negation must be ruled out explicitly.
    traits.output_index  = contains("idx") && !contains("noidx");
    traits.propagate_nan = contains("nan") && !contains("nonan");
    traits.cross_warp    = contains("crosswarp") && !contains("nocrosswarp");
    return traits;
}
} // namespace ck_tile

View File

@@ -0,0 +1,551 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Pooling kernel instance builder for tile_engine.
Generates C++ kernel headers for pooling operations with specific tile
configurations and trait combinations.
Usage:
--list_kernels: List valid kernel configurations
--gen_single: Generate a single kernel header
--gen_individual: Generate all kernel headers
"""
import os
import json
import argparse
import itertools
import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from pooling_validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
get_dtype_string,
get_reduce_op_string,
)
logger = logging.getLogger(__name__)
class PoolingKernelBuilder:
    """Generates standalone C++ kernel headers for pooling tile_engine.

    Enumerates tile configurations (from a JSON config or built-in defaults)
    and trait combinations (reduce op, index output, NaN propagation, pooling
    dimension), filters them through the validation utilities, and emits one
    self-contained header per valid kernel instance.
    """

    def __init__(self, working_path, datatype, config_json=None):
        """Set up the output directory and load the configuration.

        :param working_path: directory where generated files are written
        :param datatype: short datatype name, e.g. "fp16"
        :param config_json: optional path to a JSON configuration file
        """
        self.working_path = Path(working_path)
        self.datatype = datatype
        self.config_json = config_json
        # Create working directory if it doesn't exist
        self.working_path.mkdir(parents=True, exist_ok=True)
        # Load configuration; fall back to built-in defaults when no file given
        if config_json and os.path.exists(config_json):
            with open(config_json, "r") as f:
                self.config = json.load(f)
        else:
            self.config = self._get_default_config()

    def _get_default_config(self):
        """Return default configuration if no config file is provided."""
        return {
            "tile_config": {
                "block_m": {"values": [64, 128, 256]},
                "block_n": {"values": [1, 2]},
                "warp_m": {"values": [1]},
                "warp_n": {"values": [1]},
                "warp_tile_m": {"values": [128]},
                "warp_tile_n": {"values": [1]},
                "thread_tile_m": {"values": [1, 2, 4]},
                "thread_tile_n": {"values": [1]},
            },
            "trait_config": {
                "reduce_op": {"values": ["max", "min", "avg"]},
                "output_index": {"values": [True, False]},
                "propagate_nan": {"values": [True, False]},
                "pooling_dim": {"values": ["2d", "3d"]},
            },
        }

    # Keys of a tile configuration, in the order used throughout this builder.
    _TILE_KEYS = (
        "block_m",
        "block_n",
        "warp_m",
        "warp_n",
        "warp_tile_m",
        "warp_tile_n",
        "thread_tile_m",
        "thread_tile_n",
    )

    def _get_tile_configs(self, fast_mode=False):
        """Enumerate valid tile configurations as dicts keyed by _TILE_KEYS."""
        if "tile_config" not in self.config:
            return []
        tile_config = self.config["tile_config"]
        defaults = {
            "block_m": [64, 128, 256],
            "block_n": [1, 2],
            "warp_m": [1],
            "warp_n": [1],
            "warp_tile_m": [128],
            "warp_tile_n": [1],
            "thread_tile_m": [1, 2, 4],
            "thread_tile_n": [1],
        }
        axes = [
            tile_config.get(key, {}).get("values", defaults[key])
            for key in self._TILE_KEYS
        ]
        configs = []
        # itertools.product iterates the rightmost axis fastest, matching the
        # order of the equivalent nested loops.
        for combo in itertools.product(*axes):
            if self._validate_tile_config(*combo, fast_mode=fast_mode):
                configs.append(dict(zip(self._TILE_KEYS, combo)))
        return configs

    def _validate_tile_config(
        self,
        block_m,
        block_n,
        warp_m,
        warp_n,
        warp_tile_m,
        warp_tile_n,
        thread_tile_m,
        thread_tile_n,
        fast_mode=False,
    ):
        """Validate tile configuration via pooling_validation_utils."""
        return is_tile_config_valid(
            block_m,
            block_n,
            warp_m,
            warp_n,
            warp_tile_m,
            warp_tile_n,
            thread_tile_m,
            thread_tile_n,
            self.datatype,
            self.datatype,
            fast_mode=fast_mode,
        )

    def _generate_trait_combinations(self):
        """Return valid (reduce_op, output_index, propagate_nan, pooling_dim) tuples."""
        if "trait_config" not in self.config:
            return [("max", True, False, "2d")]
        trait_config = self.config["trait_config"]
        reduce_ops = trait_config.get("reduce_op", {}).get("values", ["min", "max", "avg"])
        output_indices = trait_config.get("output_index", {}).get("values", [True, False])
        propagate_nans = trait_config.get("propagate_nan", {}).get("values", [True, False])
        pooling_dims = trait_config.get("pooling_dim", {}).get("values", ["2d", "3d"])
        # Filter valid combinations
        combinations = []
        for combo in itertools.product(
            reduce_ops, output_indices, propagate_nans, pooling_dims
        ):
            reduce_op, output_index, propagate_nan, pooling_dim = combo
            if is_trait_combination_valid(
                reduce_op, output_index, propagate_nan, pooling_dim
            ):
                combinations.append(combo)
            else:
                logger.debug(
                    f"Skipping unsupported trait combination: {reduce_op}-{output_index}-{propagate_nan}-{pooling_dim}"
                )
        return combinations

    def _get_dtype_string(self):
        """Get C++ type string for datatype."""
        return get_dtype_string(self.datatype)

    def _get_reduce_op_string(self, reduce_op):
        """Get C++ reduce op type string."""
        return get_reduce_op_string(reduce_op)

    def _kernel_name(self, trait_combo):
        """Base kernel name (datatype + traits), without the tile suffix."""
        reduce_op, output_index, propagate_nan, pooling_dim = trait_combo
        return (
            f"pool_{self.datatype}_{pooling_dim}_{reduce_op}_"
            f"{'idx' if output_index else 'noidx'}_"
            f"{'nan' if propagate_nan else 'nonan'}"
        )

    @staticmethod
    def _tile_str(tile_config):
        """Tile suffix "<bm>x<bn>_<wm>x<wn>_<wtm>x<wtn>_<ttm>x<ttn>"."""
        return (
            f"{tile_config['block_m']}x{tile_config['block_n']}_"
            f"{tile_config['warp_m']}x{tile_config['warp_n']}_"
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}_"
            f"{tile_config['thread_tile_m']}x{tile_config['thread_tile_n']}"
        )

    def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
        """Generate one kernel instance; returns (kernel_name, source_text)."""
        reduce_op, output_index, propagate_nan, pooling_dim = trait_combo
        # Kernel name = traits prefix + tile configuration suffix
        kernel_name = f"{self._kernel_name(trait_combo)}_{self._tile_str(tile_config)}"
        # Determine types
        in_type = self._get_dtype_string()
        out_type = in_type
        compute_type = "float"  # Always use float for computation
        index_type = "ck_tile::index_t"
        reduce_op_type = self._get_reduce_op_string(reduce_op)
        output_index_str = "true" if output_index else "false"
        propagate_nan_str = "true" if propagate_nan else "false"
        # Generate 2D or 3D specific code
        if pooling_dim == "2d":
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t>"
            window_rank = 2
        else:
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = (
                "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            )
            window_rank = 3
        pragma_line = "#pragma once\n" if is_header else ""
        instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/pooling.hpp"
using InDataType = {in_type};
using OutDataType = {out_type};
using ComputeDataType = {compute_type};
using IndexDataType = {index_type};
using ReduceOpType = {reduce_op_type};
using TensorShape = {tensor_shape_type};
using WindowShape = {window_shape_type};
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
constexpr int POOLING_DIM = {window_rank};
// Wrapper for simplified launch interface
struct SelectedKernel {{
    // Tile configuration - PoolShape parameters
    static constexpr ck_tile::index_t Block_M = {tile_config["block_m"]};
    static constexpr ck_tile::index_t Block_N = {tile_config["block_n"]};
    static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
    static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
    static constexpr ck_tile::index_t WarpTile_M = {tile_config["warp_tile_m"]};
    static constexpr ck_tile::index_t WarpTile_N = {tile_config["warp_tile_n"]};
    static constexpr ck_tile::index_t ThreadTile_M = {tile_config["thread_tile_m"]};
    static constexpr ck_tile::index_t ThreadTile_N = {tile_config["thread_tile_n"]};
    // Traits
    static constexpr bool kOutputIndex = {output_index_str};
    static constexpr bool kPropagateNan = {propagate_nan_str};
    // Pool shape
    using BlockWarps = ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N>;
    using BlockTile = ck_tile::sequence<Block_M, Block_N>;
    using WarpTile = ck_tile::sequence<WarpTile_M, WarpTile_N>;
    using ThreadTile = ck_tile::sequence<ThreadTile_M, ThreadTile_N>;
    using PoolShapeType = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    // Problem and kernel types
    using Problem = ck_tile::PoolProblem<InDataType,
                                         OutDataType,
                                         ComputeDataType,
                                         IndexDataType,
                                         ReduceOpType,
                                         kOutputIndex,
                                         kPropagateNan,
                                         PoolShapeType>;
    using Kernel = ck_tile::PoolKernel<Problem>;
    static float launch(ck_tile::PoolHostArgs<TensorShape, WindowShape>& args,
                        const ck_tile::stream_config& stream) {{
        constexpr ck_tile::index_t kBlockPerCu = 1;
        const ck_tile::index_t kBlockSize = Kernel::BlockSize();
        auto kernel_args = Kernel::MakeKernelArgs(args);
        if (!Kernel::IsSupportedArgument(kernel_args)) {{
            throw std::runtime_error(
                std::string("Unsupported arguments for pooling kernel: ") + KERNEL_NAME);
        }}
        const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
        if(stream.log_level_ > 0) {{
            std::cout << "Launching pooling kernel: " << KERNEL_NAME << "\\n"
                      << "  grid_size: " << kGridSize << ", block_size: " << kBlockSize
                      << std::endl;
        }}
        return ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, kGridSize, kBlockSize, 0, kernel_args));
    }}
}};
"""
        return kernel_name, instance_code

    def write_kernel_list(self):
        """Write kernel count/list files for CMake to read."""
        tile_configs = self._get_tile_configs(fast_mode=False)
        trait_combos = self._generate_trait_combinations()
        kernel_list = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                reduce_op, output_index, propagate_nan, pooling_dim = trait_combo
                # Same naming scheme as _generate_kernel_instance
                tile_str = self._tile_str(tile_config)
                kernel_name = f"{self._kernel_name(trait_combo)}_{tile_str}"
                trait_str = (
                    f"{reduce_op}_"
                    f"{'true' if output_index else 'false'}_"
                    f"{'true' if propagate_nan else 'false'}_"
                    f"{pooling_dim}"
                )
                kernel_list.append(
                    {
                        "name": kernel_name,
                        "tile_config": tile_config,
                        "trait_combo": trait_combo,
                        "tile_str": tile_str,
                        "trait_str": trait_str,
                    }
                )
        # Write kernel count
        with open(self.working_path / "pool_kernel_count.txt", "w") as f:
            f.write(str(len(kernel_list)))
        # Write kernel list, one "name|tile|traits" record per line
        with open(self.working_path / "pool_kernel_list.txt", "w") as f:
            for kernel in kernel_list:
                f.write(
                    f"{kernel['name']}|{kernel['tile_str']}|{kernel['trait_str']}\n"
                )
        print(f"Listed {len(kernel_list)} kernel configurations")

    def generate_individual(self, num_workers=None):
        """Generate individual kernel files with parallel processing."""
        if num_workers is None:
            num_workers = min(multiprocessing.cpu_count(), 8)
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.working_path,
                        self.datatype,
                    )
                )
        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        kernel_list = []
        completed = 0
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 10 == 0 or completed == len(work_items):
                    print(
                        f"  Progress: {completed}/{len(work_items)} kernels generated"
                    )
                try:
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")
        kernel_list.sort(key=lambda x: x[0])
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )

    def run(self, num_workers=None):
        """Run the builder to generate individual kernel files."""
        self.generate_individual(num_workers)
def _generate_single_kernel_individual(work_item):
    """Worker: emit one kernel header file.

    ``work_item`` is ``(tile_config, trait_combo, working_path, datatype)``.
    Returns ``(kernel_name, trait_combo, tile_config)`` on success, or
    ``None`` when generation or writing fails.
    """
    tile_cfg, traits, out_dir, dtype = work_item
    builder = PoolingKernelBuilder(out_dir, dtype)
    try:
        name, code = builder._generate_kernel_instance(tile_cfg, traits)
        with open(out_dir / f"pooling_single_{name}.hpp", "w") as header:
            header.write(code)
        return (name, traits, tile_cfg)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """CLI entry point: list, generate-all, or generate-one pooling kernel headers."""
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description="Pooling kernel instance builder for tile_engine"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp8", "fp16", "bf16", "fp32"],
        help="Data type",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_individual", action="store_true", help="Generate individual kernel files"
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()
    builder = PoolingKernelBuilder(args.working_path, args.datatype, args.config_json)
    if args.list_kernels:
        # Emit pool_kernel_count.txt / pool_kernel_list.txt for CMake to read.
        builder.write_kernel_list()
    elif args.gen_single:
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config:
        # "<block_m>x<block_n>_<warp_m>x<warp_n>_<warp_tile_m>x<warp_tile_n>_<thread_tile_m>x<thread_tile_n>"
        tile_parts = args.tile_config.split("_")
        block_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        thread_tile_dims = tile_parts[3].split("x")
        tile_config = {
            "block_m": int(block_dims[0]),
            "block_n": int(block_dims[1]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "thread_tile_m": int(thread_tile_dims[0]),
            "thread_tile_n": int(thread_tile_dims[1]),
        }
        # Parse trait combo: "<reduce_op>_<output_index>_<propagate_nan>_<pooling_dim>"
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # reduce_op
            trait_parts[1].lower() == "true",  # output_index
            trait_parts[2].lower() == "true",  # propagate_nan
            trait_parts[3],  # pooling_dim
        )
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )
        header_file = builder.working_path / f"pooling_single_{kernel_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)
        print(f"Generated {header_file}")
    elif args.gen_individual:
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_individual, or --gen_single"
        )
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,487 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Validation utilities for pooling tile_engine configurations.
Validates tile configurations, trait combinations, and datatype support for
pooling kernels. Modelled after gemm_validation_utils.py — each constraint
from the CK PoolShape / PoolKernel static_asserts is mirrored here so that
invalid configs are rejected at code-generation time rather than at compile
or runtime.
"""
import logging
from typing import List, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------
# Default warp size (wave64 for CDNA architectures)
WARP_SIZE = 64
MAX_BLOCK_SIZE = 1024 # Maximum threads per workgroup on AMD GPUs
MAX_LDS_BYTES = 65536 # 64 KB LDS per workgroup
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Get the warp size for a given GPU target.

    CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront).
    RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront).
    """
    # Any gfx9* target is CDNA (WAVE64); everything else is treated as WAVE32.
    return 64 if gpu_target.startswith("gfx9") else 32
# ---------------------------------------------------------------------------
# Datatype helpers
# ---------------------------------------------------------------------------
# Bytes per element for every datatype name the validators understand.
# NOTE: "int4" is sub-byte (0.5 B), so element_size() returns a float.
ELEMENT_SIZE_MAP = {
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "fp16": 2,
    "bf16": 2,
    "int4": 0.5,
    "int32": 4,
    "fp32": 4,
    "fp64": 8,
}
# Mapping from datatype names to the C++ type strings emitted in generated code.
# NOTE(review): the integer types present in ELEMENT_SIZE_MAP (int4/int8/int32)
# have no entry here, so they cannot be rendered as C++ types — confirm
# whether that asymmetry is intentional.
DTYPE_STRING_MAP = {
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp32": "float",
    "fp64": "double",
}
# Datatypes that can be rendered into generated C++ code.
SUPPORTED_DATATYPES = list(DTYPE_STRING_MAP.keys())
# ---------------------------------------------------------------------------
# Reduce-op helpers
# ---------------------------------------------------------------------------
# Pooling reduce-op names mapped to ck_tile ReduceOp enumerators;
# average pooling maps to an Add reduction.
REDUCE_OP_STRING_MAP = {
    "max": "ck_tile::ReduceOp::Max",
    "min": "ck_tile::ReduceOp::Min",
    "avg": "ck_tile::ReduceOp::Add",
}
# Reduce-op names accepted by is_trait_combination_valid().
SUPPORTED_REDUCE_OPS = list(REDUCE_OP_STRING_MAP.keys())
# Pooling dimensionalities accepted by is_trait_combination_valid().
SUPPORTED_POOLING_DIMS = ("2d", "3d")
# ---------------------------------------------------------------------------
# Public helper functions (used by the instance builder)
# ---------------------------------------------------------------------------
def element_size(datatype: str) -> float:
    """Return the byte-width of a single element for *datatype*.

    Raises ValueError for names not present in ELEMENT_SIZE_MAP.
    """
    key = datatype.lower()
    size = ELEMENT_SIZE_MAP.get(key)
    if size is None:
        raise ValueError(
            f"Unsupported data type: '{key}'. "
            f"Supported: {list(ELEMENT_SIZE_MAP.keys())}"
        )
    return size
def get_dtype_string(datatype: str) -> str:
    """Return the C++ type string (e.g. ``ck_tile::fp16_t``) for *datatype*.

    NOTE(review): unknown names silently fall back to ``"float"`` instead of
    raising like element_size() — confirm this default is intended.
    """
    fallback = "float"
    return DTYPE_STRING_MAP.get(datatype, fallback)
def get_reduce_op_string(reduce_op: str) -> str:
    """Return the C++ ReduceOp enumerator string for *reduce_op*.

    Unknown names fall back to ``ck_tile::ReduceOp::Max`` (same behavior as
    the dict default in the original implementation).
    """
    fallback = "ck_tile::ReduceOp::Max"
    return REDUCE_OP_STRING_MAP.get(reduce_op, fallback)
# ---------------------------------------------------------------------------
# Individual tile-config validators
# ---------------------------------------------------------------------------
def validate_positivity(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters must be positive integers.

    Returns (True, "") on success, otherwise (False, <reason>) naming the
    first non-positive parameter in declaration order.
    """
    names = (
        "block_m", "block_n", "warp_m", "warp_n",
        "warp_tile_m", "warp_tile_n", "thread_tile_m", "thread_tile_n",
    )
    values = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )
    for name, val in zip(names, values):
        if val <= 0:
            return False, f"{name} ({val}) must be > 0"
    return True, ""
def validate_power_of_two(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters should be powers of two for correct GPU addressing.

    Non-positive values are skipped here (validate_positivity reports them);
    the first positive non-power-of-two parameter is reported.
    """

    def _pow2(v: int) -> bool:
        # Classic bit trick: a power of two has exactly one bit set.
        return (v & (v - 1)) == 0

    names = (
        "block_m", "block_n", "warp_m", "warp_n",
        "warp_tile_m", "warp_tile_n", "thread_tile_m", "thread_tile_n",
    )
    values = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )
    for name, val in zip(names, values):
        if val > 0 and not _pow2(val):
            return False, f"{name} ({val}) is not a power of two"
    return True, ""
def validate_thread_tile_alignment(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert(Warp_M % ThreadTile_M == 0);
        static_assert(Warp_N % ThreadTile_N == 0);
    """
    # Check M first, then N, so the reported failure matches the
    # static_assert ordering in the C++ header.
    checks = (
        ("warp_tile_m", warp_tile_m, "thread_tile_m", thread_tile_m),
        ("warp_tile_n", warp_tile_n, "thread_tile_n", thread_tile_n),
    )
    for warp_name, warp_val, thread_name, thread_val in checks:
        if warp_val % thread_val != 0:
            return (
                False,
                f"{warp_name} ({warp_val}) must be divisible by "
                f"{thread_name} ({thread_val})",
            )
    return True, ""
def validate_warp_thread_distribution(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N)
                      % get_warp_size() == 0);

    Fix: with plain floor division, a thread tile larger than the warp tile
    yields ``threads_per_warp == 0`` and ``0 % warp_size == 0`` then passes
    incorrectly.  The division must also be exact for the thread count to be
    meaningful, so inexact divisions are now rejected explicitly.  Configs
    that previously passed for legitimate reasons are unaffected.
    """
    warp_elems = warp_tile_m * warp_tile_n
    thread_elems = thread_tile_m * thread_tile_n
    # Guard against truncated/zero thread counts before the modulo check.
    if warp_elems % thread_elems != 0:
        return (
            False,
            f"(warp_tile_m * warp_tile_n) = {warp_elems} is not an exact "
            f"multiple of (thread_tile_m * thread_tile_n) = {thread_elems}",
        )
    threads_per_warp = warp_elems // thread_elems
    if threads_per_warp % warp_size != 0:
        return (
            False,
            f"(warp_tile_m * warp_tile_n) / (thread_tile_m * thread_tile_n) = "
            f"{threads_per_warp} is not a multiple of warp_size ({warp_size})",
        )
    return True, ""
def _compute_warp_size_scale_factors(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[int, int]:
    """
    Reproduce the WarpSizeScaleFactor_M / _N logic from pool_shape.hpp.

    The surplus warp multiple is attached to whichever dimension hosts
    strictly more threads (M wins only when threads_m > threads_n).
    """
    total_threads = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    scale = total_threads // warp_size
    threads_m = warp_tile_m // thread_tile_m
    threads_n = warp_tile_n // thread_tile_n
    return (scale, 1) if threads_m > threads_n else (1, scale)
def validate_block_tile_coverage(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Block_M * WarpSizeScaleFactor_M) %
                      (WarpPerBlock_M * Warp_M) == 0);
        static_assert((Block_N * WarpSizeScaleFactor_N) %
                      (WarpPerBlock_N * Warp_N) == 0);
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    # Check M before N to match the static_assert order in the header.
    for dim, blk, sf, warps, wtile in (
        ("M", block_m, sf_m, warp_m, warp_tile_m),
        ("N", block_n, sf_n, warp_n, warp_tile_n),
    ):
        if (blk * sf) % (warps * wtile) != 0:
            d = dim.lower()
            return (
                False,
                f"block_{d}*ScaleFactor_{dim} ({blk}*{sf}={blk * sf}) must be "
                f"divisible by warp_{d}*warp_tile_{d} ({warps}*{wtile}"
                f"={warps * wtile})",
            )
    return True, ""
def validate_block_size(
    warp_m: int,
    warp_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """BlockSize = warp_size * warp_m * warp_n must be <= MAX_BLOCK_SIZE."""
    block_size = warp_size * warp_m * warp_n
    if block_size <= MAX_BLOCK_SIZE:
        return True, ""
    return (
        False,
        f"BlockSize ({block_size} = {warp_size}*{warp_m}*{warp_n}) "
        f"exceeds maximum ({MAX_BLOCK_SIZE})",
    )
def validate_vector_load_alignment(
    block_m: int,
    thread_tile_m: int,
    in_datatype: str,
) -> Tuple[bool, str]:
    """
    The M-dimension thread-tile determines the contiguous vector load width.
    It must produce a load whose byte-width divides 16 bytes (max global
    vector load width on AMD GPUs) and is at least 1 element wide.

    NOTE(review): ``block_m`` is not used by this check; it is kept only so
    the signature stays aligned with the other validators — confirm intent.
    """
    elem_bytes = element_size(in_datatype)
    load_bytes = thread_tile_m * elem_bytes
    if load_bytes > 16:
        msg = (
            f"thread_tile_m ({thread_tile_m}) * element_size({in_datatype}, "
            f"{elem_bytes}B) = {load_bytes}B exceeds 16B max vector load"
        )
        return False, msg
    # De Morgan of the original "not a divisor" condition.
    divides_16 = (16 % load_bytes == 0) or (load_bytes % 16 == 0)
    if not divides_16:
        return (
            False,
            f"Vector load width ({load_bytes}B) is not a divisor of 16B",
        )
    return True, ""
def validate_repeat_factors(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Repeat_M and Repeat_N from pool_shape.hpp must be >= 1. They are the
    number of tile iterations each warp performs within the block.
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
    )
    repeats = (
        ("Repeat_M", (block_m * sf_m) // (warp_m * warp_tile_m)),
        ("Repeat_N", (block_n * sf_n) // (warp_n * warp_tile_n)),
    )
    for name, repeat in repeats:
        if repeat < 1:
            return False, f"{name} ({repeat}) must be >= 1"
    return True, ""
# ---------------------------------------------------------------------------
# Comprehensive tile-config validation (entry point)
# ---------------------------------------------------------------------------
def is_tile_config_valid(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    in_datatype: str,
    out_datatype: str,
    fast_mode: bool = False,
    gpu_target: str = "gfx90a",
) -> bool:
    """
    Comprehensive pooling tile configuration validation.

    When *fast_mode* is True only cheap sanity checks are performed (useful
    for the ``--list_kernels`` path). Full mode mirrors every
    ``static_assert`` in ``pool_shape.hpp``.

    Parameters
    ----------
    block_m, block_n : Block tile dimensions (M = output elems, N = window).
    warp_m, warp_n : Warps per block along each dimension.
    warp_tile_m, warp_tile_n : Tile processed per warp.
    thread_tile_m, thread_tile_n : Contiguous elements per thread.
    in_datatype : Input element type (e.g. ``"fp16"``).
    out_datatype : Output element type (currently unused by the checks).
    fast_mode : Skip expensive checks when True.
    gpu_target : Architecture string used to pick the warp size.
    """
    dims = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )
    warp_dims = (warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n)

    # --- Cheap checks, run in every mode ---
    ok, err = validate_positivity(*dims)
    if not ok:
        logger.debug(f"Positivity check failed: {err}")
        return False
    ok, err = validate_thread_tile_alignment(*warp_dims)
    if not ok:
        logger.debug(f"Thread tile alignment failed: {err}")
        return False
    if fast_mode:
        return True

    # --- Full mode: lazy table preserves the original short-circuit order ---
    warp_size = get_warp_size_for_gpu(gpu_target)
    checks = (
        ("Power-of-two check failed",
         lambda: validate_power_of_two(*dims)),
        ("Warp thread distribution failed",
         lambda: validate_warp_thread_distribution(*warp_dims, warp_size)),
        ("Block tile coverage failed",
         lambda: validate_block_tile_coverage(*dims, warp_size=warp_size)),
        ("Block size check failed",
         lambda: validate_block_size(warp_m, warp_n, warp_size)),
        ("Repeat factor check failed",
         lambda: validate_repeat_factors(*dims)),
        ("Vector load alignment failed",
         lambda: validate_vector_load_alignment(block_m, thread_tile_m, in_datatype)),
    )
    for label, run in checks:
        ok, err = run()
        if not ok:
            logger.debug(f"{label}: {err}")
            return False
    return True
# ---------------------------------------------------------------------------
# Trait-combination validation
# ---------------------------------------------------------------------------
def is_trait_combination_valid(
    reduce_op: str,
    output_index: bool,
    propagate_nan: bool,
    pooling_dim: str,
) -> bool:
    """
    Validate a pooling trait combination.

    Parameters
    ----------
    reduce_op : ``"max"``, ``"min"``, or ``"avg"``.
    output_index : Whether to output indices of the selected elements.
    propagate_nan: Whether to propagate NaN values through the reduction.
        NOTE(review): not checked here — any value is accepted; confirm
        whether a constraint was intended.
    pooling_dim : ``"2d"`` or ``"3d"``.
    """
    if reduce_op not in SUPPORTED_REDUCE_OPS:
        logger.debug(f"Unsupported reduce_op: '{reduce_op}'")
        return False
    if pooling_dim not in SUPPORTED_POOLING_DIMS:
        logger.debug(f"Invalid pooling dimension: '{pooling_dim}'")
        return False
    # Index output is only meaningful for max pooling (CK constraint).
    index_ok = (not output_index) or reduce_op == "max"
    if not index_ok:
        logger.debug(
            f"output_index=True is only supported for 'max' pooling, "
            f"not '{reduce_op}'"
        )
        return False
    return True
# ---------------------------------------------------------------------------
# Datatype validation
# ---------------------------------------------------------------------------
def is_datatype_supported(datatype: str) -> bool:
    """Return True if *datatype* is a known pooling datatype.

    NOTE(review): membership is tested against ELEMENT_SIZE_MAP, which also
    lists integer types that DTYPE_STRING_MAP cannot render as C++ types —
    confirm the intended support set.
    """
    key = datatype.lower()
    return key in ELEMENT_SIZE_MAP

View File

@@ -11,7 +11,7 @@ set(MULTI_REDUCE_VARIANTS "multiops_multiblock;multiops_threadwise" CACHE STRING
function(build_multi_reduce_for_datatype datatype variant)
# Filter GPU targets to only gfx942, and gfx950
set(GPU_TARGETS "")
set(DESIRED_TARGETS "gfx942;gfx950")
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
set(VALID_VARIANTS "multiops_multiblock;multiops_threadwise")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
@@ -22,7 +22,7 @@ function(build_multi_reduce_for_datatype datatype variant)
# Skip compilation if no matching targets found
if(NOT GPU_TARGETS)
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()