[rocm-libraries] ROCm/rocm-libraries#4469 (commit 0844cb0)

[CK_TILE] Add pooling in tile_engine

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->
Add pooling (2D and 3D max/min/avg) kernel benchmark support to the CK Tile engine.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
aledudek
2026-04-01 07:32:36 +00:00
committed by assistant-librarian[bot]
parent 791afc6465
commit 119712bd90
25 changed files with 3258 additions and 19 deletions

View File

@@ -7,5 +7,6 @@ include_directories(BEFORE
add_subdirectory(ops/gemm)
add_subdirectory(ops/gemm_streamk)
add_subdirectory(ops/pooling)
add_subdirectory(ops/reduce)

View File

@@ -231,7 +231,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950
set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -242,7 +242,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -219,7 +219,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, and gfx950
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -230,7 +230,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -237,7 +237,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -25,6 +25,16 @@ ELEMENT_SIZE_MAP = {
"fp64": 8,
}
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Return the wavefront (warp) size for a GPU target string.

    CDNA architectures (gfx9xx) execute 64-lane wavefronts (WAVE64), while
    RDNA architectures (gfx10xx, gfx11xx, gfx12xx) and anything unrecognized
    default to 32-lane wavefronts (WAVE32).
    """
    is_cdna = gpu_target.startswith("gfx9")
    return 64 if is_cdna else 32
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
@@ -586,10 +596,11 @@ def validate_whole_wg_cover_configuration(
layout,
a_datatype,
b_datatype,
gpu_target: str = "gfx90a",
) -> Tuple[bool, str]:
# Validate whole workgroup cover configuration
warp_size = 64
warp_size = get_warp_size_for_gpu(gpu_target)
NumWarps = warp_m * warp_n * warp_k
BlockSize = NumWarps * warp_size
@@ -704,6 +715,73 @@ def wg_cover_core_validation(
return True, ""
def validate_cshuffle_epilogue_distribution(
    tile_m: int,
    tile_n: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_size: int,
    c_datatype: str,
) -> Tuple[bool, str]:
    """
    Validate that the CShuffleEpilogue tile distribution pattern is valid.

    This mirrors the static_assert in static_encoding_pattern.hpp:
        static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");

    The CShuffleEpilogue creates a
    tile_distribution_encoding_pattern_2d<BlockSize, YPerTile, XPerTile, VecSize, thread_raked>
    where:
    - BlockSize = warp_m * warp_n * warp_k * warp_size
    - YPerTile  = MPerIterationShuffle (derived here as tile_m // warp_m)
    - XPerTile  = NPerIterationShuffle (tile_n)
    - VecSize   = vector size based on element size (typically 8 for fp16)

    The key constraint is that X0 must evenly divide warp_size, where:
    - X0 = min(warp_size, XPerTile / X1)
    - X1 = min(VecSize, LargestVec)
    - LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)

    NOTE(review): warp_tile_m / warp_tile_n are accepted for signature parity
    with the other validators but are not used by this check.

    Returns:
        (True, "") when the distribution is valid, otherwise (False, reason).
    """
    NumWarps = warp_m * warp_n * warp_k
    BlockSize = NumWarps * warp_size
    # Unknown datatypes fall back to 2 bytes (fp16-sized elements).
    elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2)
    # VecSize mirrors the epilogue's 16-byte vectorized store.
    VecSize = 16 // elem_size
    XPerTile = tile_n
    YPerTile = tile_m // warp_m
    if XPerTile <= 0 or YPerTile <= 0:
        return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}"
    num_warps = BlockSize // warp_size
    # Guards the division below when warp counts degenerate to zero.
    if num_warps * warp_size == 0:
        return False, "Invalid BlockSize or warp_size"
    LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size)
    if LargestVec <= 0:
        LargestVec = 1
    # X1 >= 1 always holds here: VecSize >= 1 (elem_size <= 16) and
    # LargestVec >= 1 after the clamp above, so no extra fallbacks are needed.
    X1 = min(VecSize, LargestVec)
    X0 = min(warp_size, XPerTile // X1)
    # X0 can still be 0 when XPerTile < X1; keep the guard so Y1 stays defined.
    Y1 = warp_size // X0 if X0 > 0 else 0
    if X0 * Y1 != warp_size:
        return (
            False,
            f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). "
            f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}"
        )
    return True, ""
def get_global_vector_load_size(
BlockSize: int,
KPerBlock: int,
@@ -766,6 +844,8 @@ def validate_gemm(
trait_name: str = None,
) -> bool:
# GEMM Validation
warp_size = get_warp_size_for_gpu(gpu_target)
# Validate whole workgroup cover configuration
whole_workgroup_cover_valid, whole_workgroup_cover_error = (
validate_whole_wg_cover_configuration(
@@ -778,6 +858,7 @@ def validate_gemm(
layout,
a_datatype,
b_datatype,
gpu_target,
)
)
if not whole_workgroup_cover_valid:
@@ -786,6 +867,23 @@ def validate_gemm(
)
return False, whole_workgroup_cover_error
# Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue)
# This validation ensures the tile distribution pattern is valid for the output tile
cshuffle_valid, cshuffle_error = validate_cshuffle_epilogue_distribution(
tile_m,
tile_n,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_size,
c_datatype,
)
if not cshuffle_valid:
logging.debug(f"CShuffleEpilogue validation failed: {cshuffle_error}")
return False, cshuffle_error
return True, ""
@@ -808,6 +906,8 @@ def validate_gemm_preshuffle(
trait_name: str = None,
) -> bool:
# Preshuffle Validations
warp_size = get_warp_size_for_gpu(gpu_target)
# Validate vector load alignment
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
vector_valid, vector_error = validate_vector_load_alignment(
@@ -815,7 +915,7 @@ def validate_gemm_preshuffle(
warp_tile_k,
a_datatype,
m_iter_per_warp,
wave_size=64,
wave_size=warp_size,
vector_load_size=16,
)
if not vector_valid:
@@ -831,7 +931,7 @@ def validate_gemm_preshuffle(
warp_k,
a_datatype,
vector_load_size=16,
warp_size=64,
warp_size=warp_size,
)
if not m0_m1_m2_valid:
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")

View File

@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx942, gfx950
set(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx942;gfx950")
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -237,7 +237,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual Grouped GEMM targets for GPU targets: ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -216,7 +216,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported
set(DESIRED_TARGETS "gfx90a;gfx942;gfx12-generic") # TODO: Add gfx950 when supported
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -227,7 +227,7 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")

View File

@@ -0,0 +1,212 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ============================================================================
# Pooling Tile Engine Build Configuration
#
# Generates individual benchmark executables for pooling kernels
# ============================================================================
set(POOLING_DATATYPE "fp8;fp16;fp32" CACHE STRING "List of datatypes for Pooling (semicolon-separated)")
set(POOLING_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_POOLING "Enable ccache for pooling ops compilation" OFF)
# Store the directory path for use in functions
set(POOLING_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# ============================================================================
# create_individual_pool_target
#
# Creates a single benchmark executable for a specific pooling kernel config.
# ============================================================================
function(create_individual_pool_target datatype kernel_name trait tile_config config_json)
    if(NOT POOLING_GPU_TARGETS)
        message(WARNING "Skipping individual pooling target: No supported GPU targets")
        return()
    endif()

    set(target_name "benchmark_pooling_${datatype}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")

    # HIP clang offload uses temporary files derived from the input source basename.
    # When many targets compile the same source filename in parallel, temporary
    # files can collide and corrupt each other. Use a unique copied source per target.
    set(target_source "${CMAKE_CURRENT_BINARY_DIR}/${target_name}_pooling_benchmark_single.cpp")

    # Generated header path - use kernel_name from pool_kernel_list.txt to match
    # the filename generated by pooling_instance_builder.py
    set(instance_header "${working_path}/pooling_single_${kernel_name}.hpp")

    # Generate the kernel instance header at build time.
    # VERBATIM ensures portable, generator-independent argument escaping.
    add_custom_command(
        OUTPUT ${instance_header}
        COMMAND ${Python3_EXECUTABLE} ${POOLING_SOURCE_DIR}/pooling_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --config_json ${config_json}
                --gen_single
                --kernel_name "${kernel_name}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
        DEPENDS ${POOLING_SOURCE_DIR}/pooling_instance_builder.py ${config_json}
        COMMENT "Generating ${instance_header}"
        VERBATIM
    )

    configure_file(${POOLING_SOURCE_DIR}/pooling_benchmark_single.cpp ${target_source} COPYONLY)

    # Listing the header as a source wires the custom command into this target.
    add_executable(${target_name}
        ${target_source}
        ${instance_header}
    )

    # Set GPU architectures
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${POOLING_GPU_TARGETS})

    # The benchmark source locates its generated header through this macro.
    target_compile_definitions(${target_name} PRIVATE
        POOLING_SINGLE_INSTANCE_HPP="${instance_header}"
    )

    target_include_directories(${target_name} PRIVATE
        ${POOLING_SOURCE_DIR}
        ${working_path}
    )

    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )

    # FP8 format selection: use target_compile_definitions for macros
    # rather than a raw -D option.
    if(CK_USE_OCP_FP8)
        target_compile_definitions(${target_name} PRIVATE CK_TILE_USE_OCP_FP8)
    endif()

    # Register with the umbrella build targets created at file scope.
    add_dependencies(benchmark_pooling_all ${target_name})
    add_dependencies(benchmark_pooling_${datatype} ${target_name})

    message(STATUS " Created pooling benchmark target: ${target_name}")
endfunction()
# ============================================================================
# build_individual_pool_targets
#
# Builds all benchmark targets for a specific datatype.
# ============================================================================
function(build_individual_pool_targets datatype)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")

    # Resolve the kernel config JSON: environment variable > cache variable > default.
    # NOTE(review): $ENV{...} is read at configure time only; re-run CMake to pick
    # up a changed POOLING_CONFIG_FILE environment value.
    if(DEFINED ENV{POOLING_CONFIG_FILE} AND NOT "$ENV{POOLING_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{POOLING_CONFIG_FILE}")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
        message(STATUS " Using config from environment variable: ${config_filename}")
    elseif(NOT "${POOLING_CONFIG_FILE}" STREQUAL "")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${POOLING_CONFIG_FILE}")
        message(STATUS " Using custom config: ${POOLING_CONFIG_FILE}")
    else()
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
        message(STATUS " Using default config for pooling")
    endif()

    # Quote the path: config file names may contain spaces.
    if(NOT EXISTS "${json_blob}")
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()

    file(MAKE_DIRECTORY "${working_path}")

    # Step 1: enumerate kernel configurations at configure time.
    message(STATUS " Listing pooling kernel configurations for ${datatype}...")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/pooling_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --config_json ${json_blob}
                --list_kernels
        WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )
    if(NOT ret EQUAL 0)
        # Include the exit code so non-zero-but-silent failures are diagnosable.
        message(FATAL_ERROR "Failed to list pooling kernels for ${datatype} (exit code ${ret}): ${list_error}")
    endif()

    # Read kernel count (informational only; the list file drives target creation).
    if(EXISTS "${working_path}/pool_kernel_count.txt")
        file(READ "${working_path}/pool_kernel_count.txt" kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(STATUS " Found ${kernel_count} pooling kernel configurations")
    else()
        message(FATAL_ERROR "Pooling kernel count file not found")
    endif()

    # Step 2: one benchmark target per "kernel_name|tile_config|trait_combo" line.
    if(EXISTS "${working_path}/pool_kernel_list.txt")
        file(STRINGS "${working_path}/pool_kernel_list.txt" kernel_lines)
        foreach(line IN LISTS kernel_lines)
            string(REPLACE "|" ";" parts "${line}")
            list(LENGTH parts parts_len)
            if(parts_len EQUAL 3)
                list(GET parts 0 kernel_name)
                list(GET parts 1 tile_config)
                list(GET parts 2 trait_combo)
                create_individual_pool_target("${datatype}" "${kernel_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
            else()
                # Surface malformed generator output instead of skipping it silently.
                message(WARNING "Ignoring malformed pooling kernel list entry: '${line}'")
            endif()
        endforeach()
    else()
        message(FATAL_ERROR "Pooling kernel list file not found")
    endif()
endfunction()
# ============================================================================
# MAIN EXECUTION
# ============================================================================
message(STATUS "=== Starting Tile Engine Pooling Configuration ===")
message(STATUS "POOLING_DATATYPE: ${POOLING_DATATYPE}")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Keep only the GPU targets pooling is known to build for.
# NOTE(review): the other tile_engine ops in this change also accept
# gfx12-generic in their DESIRED_TARGETS — confirm whether pooling
# should include it here as well.
set(POOLING_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
    if(target IN_LIST DESIRED_TARGETS)
        list(APPEND POOLING_GPU_TARGETS ${target})
        message(STATUS " Adding GPU target for pooling: ${target}")
    endif()
endforeach()

if(NOT POOLING_GPU_TARGETS)
    message(WARNING "Skipping Tile Engine Pooling build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
    message(STATUS "Building pooling targets for GPU targets: ${POOLING_GPU_TARGETS}")

    # Optionally route compilation through ccache (opt-in via ENABLE_CCACHE_POOLING).
    # The launcher is set at directory scope, so it applies to targets created below.
    if(ENABLE_CCACHE_POOLING)
        find_program(CCACHE_PROGRAM ccache)
        if(CCACHE_PROGRAM)
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(STATUS "Using ccache for pooling compilation")
        endif()
    endif()

    # Umbrella targets: benchmark_pooling_all plus one collector per datatype.
    # These must exist before create_individual_pool_target adds dependencies.
    add_custom_target(benchmark_pooling_all)
    foreach(dt IN LISTS POOLING_DATATYPE)
        add_custom_target(benchmark_pooling_${dt})
    endforeach()

    # Generate and register every individual benchmark target per datatype.
    foreach(dt IN LISTS POOLING_DATATYPE)
        build_individual_pool_targets(${dt})
    endforeach()
endif()

View File

@@ -0,0 +1,21 @@
{
"problem": {
"description": "Default pooling configuration for tile_engine benchmarks"
},
"tile_config": {
"block_m": {"values": [64,128,256]},
"block_n": {"values": [1,2]},
"warp_m": {"values": [1]},
"warp_n": {"values": [1]},
"warp_tile_m": {"values": [128]},
"warp_tile_n": {"values": [1]},
"thread_tile_m": {"values": [1,2,4]},
"thread_tile_n": {"values": [1]}
},
"trait_config": {
"reduce_op": {"values": ["max", "min", "avg"]},
"output_index": {"values": [true, false]},
"propagate_nan": {"values": [true, false]},
"pooling_dim": {"values": ["2d", "3d"]}
}
}

View File

@@ -0,0 +1,132 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <vector>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
namespace ck_tile {

/// @brief Performance metrics for benchmarking
enum class PoolMetric
{
    LATENCY,
    BANDWIDTH
};

/// @brief Pooling problem specification for 2D pooling (NHWC layout).
struct PoolProblem2D
{
    index_t N, H, W, C;              // Input dimensions (NHWC)
    index_t Y, X;                    // Window dimensions
    index_t stride_h, stride_w;      // Window strides
    index_t dilation_h, dilation_w;  // Window dilations
    index_t pad_h_left, pad_h_right; // Height padding
    index_t pad_w_left, pad_w_right; // Width padding
    std::string datatype;            // Data type name
    std::string reduce_op;           // "max", "min", or "avg"

    // Output height via the standard pooling size formula with dilation.
    // Integer division truncates; assumes H + pads >= effective window
    // (negative numerators are not clamped — callers must guarantee this).
    index_t Ho() const
    {
        index_t Ys = (Y - 1) * dilation_h + 1; // effective (dilated) window height
        return (H + pad_h_left + pad_h_right - Ys) / stride_h + 1;
    }
    // Output width; same formula and assumptions as Ho().
    index_t Wo() const
    {
        index_t Xs = (X - 1) * dilation_w + 1; // effective (dilated) window width
        return (W + pad_w_left + pad_w_right - Xs) / stride_w + 1;
    }
    // Total element counts (not bytes) for input/output tensors.
    index_t input_elements() const { return N * H * W * C; }
    index_t output_elements() const { return N * Ho() * Wo() * C; }

    // Compact identifier string (e.g. for naming benchmark runs).
    // Note: padding is appended only when a left pad is nonzero.
    std::string to_string() const
    {
        std::ostringstream oss;
        oss << "N" << N << "_H" << H << "_W" << W << "_C" << C << "_Y" << Y << "_X" << X << "_Sh"
            << stride_h << "_Sw" << stride_w << "_Dh" << dilation_h << "_Dw" << dilation_w;
        if(pad_h_left > 0 || pad_w_left > 0)
            oss << "_Ph" << pad_h_left << "_Pw" << pad_w_left;
        return oss.str();
    }
};

/// @brief Pooling problem specification for 3D pooling (NDHWC layout).
struct PoolProblem3D
{
    index_t N, D, H, W, C;                      // Input dimensions (NDHWC)
    index_t Z, Y, X;                            // Window dimensions
    index_t stride_d, stride_h, stride_w;       // Window strides
    index_t dilation_d, dilation_h, dilation_w; // Window dilations
    index_t pad_d_left, pad_d_right;            // Depth padding
    index_t pad_h_left, pad_h_right;            // Height padding
    index_t pad_w_left, pad_w_right;            // Width padding
    std::string datatype;                       // Data type name
    std::string reduce_op;                      // "max", "min", or "avg"

    // Output depth/height/width: same truncating pooling formula as the
    // 2D case; callers must ensure the padded extent covers the window.
    index_t Do() const
    {
        index_t Zs = (Z - 1) * dilation_d + 1; // effective (dilated) window depth
        return (D + pad_d_left + pad_d_right - Zs) / stride_d + 1;
    }
    index_t Ho() const
    {
        index_t Ys = (Y - 1) * dilation_h + 1; // effective (dilated) window height
        return (H + pad_h_left + pad_h_right - Ys) / stride_h + 1;
    }
    index_t Wo() const
    {
        index_t Xs = (X - 1) * dilation_w + 1; // effective (dilated) window width
        return (W + pad_w_left + pad_w_right - Xs) / stride_w + 1;
    }
    // Total element counts (not bytes) for input/output tensors.
    index_t input_elements() const { return N * D * H * W * C; }
    index_t output_elements() const { return N * Do() * Ho() * Wo() * C; }

    // Compact identifier string. NOTE(review): unlike the 2D variant this
    // omits strides, dilations, and padding — confirm that is intended.
    std::string to_string() const
    {
        std::ostringstream oss;
        oss << "N" << N << "_D" << D << "_H" << H << "_W" << W << "_C" << C << "_Z" << Z << "_Y"
            << Y << "_X" << X;
        return oss.str();
    }
};

/// @brief Performance result for a pooling kernel
struct PoolPerformanceResult
{
    float latency_ms;      // measured kernel time in milliseconds
    float bandwidth_gb_s;  // effective memory bandwidth in GB/s

    // Human-readable one-line summary of both metrics.
    std::string to_string() const
    {
        std::ostringstream oss;
        oss << "latency=" << latency_ms << "ms, bandwidth=" << bandwidth_gb_s << "GB/s";
        return oss.str();
    }
};

/// @brief Benchmark settings
struct PoolBenchmarkSetting
{
    int warmup = 5;      // warmup iterations before timing
    int repeat = 20;     // timed iterations
    bool verify = true;  // compare against host reference
    int init_method = 0; // 0: uniform random, 1: integer sequence, 2: constant, 3: special
};

} // namespace ck_tile

View File

@@ -0,0 +1,390 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file pooling_benchmark_single.cpp
* @brief Single-kernel benchmark for pooling operations (2D and 3D).
*
* This benchmark includes the generated kernel header via -include flag
* and runs the pooling kernel with specified problem sizes.
*
* The generated header provides:
* - SelectedKernel (struct with ::launch())
* - KERNEL_NAME (constexpr const char*)
* - POOLING_DIM (constexpr int, 2 or 3)
* - InDataType, OutDataType, ComputeDataType, IndexDataType, ReduceOpType
* - TensorShape, WindowShape
*/
#include <iostream>
#include <string>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "pooling_common.hpp"
#include "pooling_benchmark.hpp"
// The kernel header is included via compile command line with -include flag.
// --------------------------------------------------------------------------
// Benchmark implementation — templated on pooling dimension so that only
// the matching branch is instantiated (2D or 3D).
// --------------------------------------------------------------------------
// Thin forwarding wrapper so both the 2D and 3D branches share one launch
// path. SelectedKernel is provided by the generated header injected via the
// -include compile flag; the return value is whatever SelectedKernel::launch
// reports (the measured kernel time, per this file's usage of `latency`).
template <typename HostArgs>
static float launch_selected_kernel(HostArgs& args, const ck_tile::stream_config& stream)
{
    return SelectedKernel::launch(args, stream);
}
template <int PoolDim>
static int benchmark_pooling(int argc, char* argv[])
{
if constexpr(PoolDim == 2)
{
// ---- 2D argument parser ----
ck_tile::ArgParser arg_parser;
arg_parser.insert("n", "1", "Batch size (N)")
.insert("h", "16", "Input height (H)")
.insert("w", "16", "Input width (W)")
.insert("c", "32", "Channels (C)")
.insert("wy", "2", "Window height (Y)")
.insert("wx", "2", "Window width (X)")
.insert("sy", "2", "Window stride height")
.insert("sx", "2", "Window stride width")
.insert("dy", "1", "Window dilation height")
.insert("dx", "1", "Window dilation width")
.insert("phy", "0", "Padding height left")
.insert("phyr", "0", "Padding height right")
.insert("pwx", "0", "Padding width left")
.insert("pwxr", "0", "Padding width right")
.insert("verify", "1", "Verify results (0/1)")
.insert("warmup", "5", "Warmup iterations")
.insert("repeat", "20", "Repeat iterations")
.insert("log", "1", "Log level");
if(!arg_parser.parse(argc, argv))
return -1;
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t H = arg_parser.get_int("h");
ck_tile::index_t W = arg_parser.get_int("w");
ck_tile::index_t C = arg_parser.get_int("c");
ck_tile::index_t Y = arg_parser.get_int("wy");
ck_tile::index_t X = arg_parser.get_int("wx");
ck_tile::index_t Sy = arg_parser.get_int("sy");
ck_tile::index_t Sx = arg_parser.get_int("sx");
ck_tile::index_t Dy = arg_parser.get_int("dy");
ck_tile::index_t Dx = arg_parser.get_int("dx");
ck_tile::index_t LeftPy = arg_parser.get_int("phy");
ck_tile::index_t RightPy = arg_parser.get_int("phyr");
ck_tile::index_t LeftPx = arg_parser.get_int("pwx");
ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
bool verify = arg_parser.get_int("verify") != 0;
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int log_level = arg_parser.get_int("log");
ck_tile::index_t Ys = (Y - 1) * Dy + 1;
ck_tile::index_t Xs = (X - 1) * Dx + 1;
ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
std::cout << "Pooling 2D benchmark: " << KERNEL_NAME << std::endl;
std::cout << " Input: NHWC = " << N << "x" << H << "x" << W << "x" << C << std::endl;
std::cout << " Output: NHWC = " << N << "x" << Ho << "x" << Wo << "x" << C << std::endl;
std::cout << " Window: " << Y << "x" << X << ", stride: " << Sy << "x" << Sx
<< ", dilation: " << Dy << "x" << Dx << std::endl;
ck_tile::HostTensor<InDataType> h_in({N, H, W, C});
ck_tile::HostTensor<OutDataType> h_out({N, Ho, Wo, C});
ck_tile::HostTensor<OutDataType> h_out_ref({N, Ho, Wo, C});
ck_tile::HostTensor<IndexDataType> h_out_index({N, Ho, Wo, C});
ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Ho, Wo, C});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
d_in.ToDevice(h_in.data());
d_out.SetZero();
d_out_index.SetZero();
auto input_shape = ck_tile::make_tuple(N, H, W, C);
auto output_shape = ck_tile::make_tuple(N, Ho, Wo, C);
auto input_strides = ck_tile::make_tuple(H * W * C, W * C, C, ck_tile::index_t{1});
auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
auto window_lengths = ck_tile::make_tuple(Y, X);
auto window_strides = ck_tile::make_tuple(Sy, Sx);
auto window_dilations = ck_tile::make_tuple(Dy, Dx);
auto input_left_pads = ck_tile::make_tuple(LeftPy, LeftPx);
auto input_right_pads = ck_tile::make_tuple(RightPy, RightPx);
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
float latency = 0;
try
{
latency = launch_selected_kernel(host_args, stream);
}
catch(const std::exception& e)
{
std::cerr << "Kernel launch failed: " << e.what() << std::endl;
return -1;
}
size_t bytes_read = static_cast<size_t>(N) * H * W * C * sizeof(InDataType);
size_t bytes_written = static_cast<size_t>(N) * Ho * Wo * C * sizeof(OutDataType);
float bandwidth = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
std::cout << " Latency: " << latency << " ms" << std::endl;
std::cout << " Bandwidth: " << bandwidth << " GB/s" << std::endl;
if(verify)
{
d_out.FromDevice(h_out.data());
d_out_index.FromDevice(h_out_index.data());
auto kernel_args =
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::reference_pool2d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
decltype(input_shape),
decltype(window_lengths),
SelectedKernel::kOutputIndex>(
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
bool pass_value =
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
std::cout << " Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
if(SelectedKernel::kOutputIndex)
{
bool pass_index = ck_tile::check_err(
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL")
<< std::endl;
}
}
return 0;
}
else // PoolDim == 3
{
// ---- 3D argument parser ----
ck_tile::ArgParser arg_parser;
arg_parser.insert("n", "1", "Batch size (N)")
.insert("d", "4", "Input depth (D)")
.insert("h", "16", "Input height (H)")
.insert("w", "16", "Input width (W)")
.insert("c", "32", "Channels (C)")
.insert("wz", "2", "Window depth (Z)")
.insert("wy", "2", "Window height (Y)")
.insert("wx", "2", "Window width (X)")
.insert("sz", "2", "Window stride depth")
.insert("sy", "2", "Window stride height")
.insert("sx", "2", "Window stride width")
.insert("dz", "1", "Window dilation depth")
.insert("dy", "1", "Window dilation height")
.insert("dx", "1", "Window dilation width")
.insert("pdz", "0", "Padding depth left")
.insert("pdzr", "0", "Padding depth right")
.insert("phy", "0", "Padding height left")
.insert("phyr", "0", "Padding height right")
.insert("pwx", "0", "Padding width left")
.insert("pwxr", "0", "Padding width right")
.insert("verify", "1", "Verify results (0/1)")
.insert("warmup", "5", "Warmup iterations")
.insert("repeat", "20", "Repeat iterations")
.insert("log", "1", "Log level");
if(!arg_parser.parse(argc, argv))
return -1;
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t D = arg_parser.get_int("d");
ck_tile::index_t H = arg_parser.get_int("h");
ck_tile::index_t W = arg_parser.get_int("w");
ck_tile::index_t C = arg_parser.get_int("c");
ck_tile::index_t Z = arg_parser.get_int("wz");
ck_tile::index_t Y = arg_parser.get_int("wy");
ck_tile::index_t X = arg_parser.get_int("wx");
ck_tile::index_t Sz = arg_parser.get_int("sz");
ck_tile::index_t Sy = arg_parser.get_int("sy");
ck_tile::index_t Sx = arg_parser.get_int("sx");
ck_tile::index_t Dz = arg_parser.get_int("dz");
ck_tile::index_t Dy = arg_parser.get_int("dy");
ck_tile::index_t Dx = arg_parser.get_int("dx");
ck_tile::index_t LeftPz = arg_parser.get_int("pdz");
ck_tile::index_t RightPz = arg_parser.get_int("pdzr");
ck_tile::index_t LeftPy = arg_parser.get_int("phy");
ck_tile::index_t RightPy = arg_parser.get_int("phyr");
ck_tile::index_t LeftPx = arg_parser.get_int("pwx");
ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
bool verify = arg_parser.get_int("verify") != 0;
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int log_level = arg_parser.get_int("log");
ck_tile::index_t Zs = (Z - 1) * Dz + 1;
ck_tile::index_t Ys = (Y - 1) * Dy + 1;
ck_tile::index_t Xs = (X - 1) * Dx + 1;
ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
std::cout << "Pooling 3D benchmark: " << KERNEL_NAME << std::endl;
std::cout << " Input: NDHWC = " << N << "x" << D << "x" << H << "x" << W << "x" << C
<< std::endl;
std::cout << " Output: NDHWC = " << N << "x" << Do << "x" << Ho << "x" << Wo << "x" << C
<< std::endl;
std::cout << " Window: " << Z << "x" << Y << "x" << X << ", stride: " << Sz << "x" << Sy
<< "x" << Sx << ", dilation: " << Dz << "x" << Dy << "x" << Dx << std::endl;
ck_tile::HostTensor<InDataType> h_in({N, D, H, W, C});
ck_tile::HostTensor<OutDataType> h_out({N, Do, Ho, Wo, C});
ck_tile::HostTensor<OutDataType> h_out_ref({N, Do, Ho, Wo, C});
ck_tile::HostTensor<IndexDataType> h_out_index({N, Do, Ho, Wo, C});
ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Do, Ho, Wo, C});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
d_in.ToDevice(h_in.data());
d_out.SetZero();
d_out_index.SetZero();
auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
auto input_strides =
ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, ck_tile::index_t{1});
auto output_strides =
ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
auto window_lengths = ck_tile::make_tuple(Z, Y, X);
auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx);
auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
float latency = 0;
try
{
latency = launch_selected_kernel(host_args, stream);
}
catch(const std::exception& e)
{
std::cerr << "Kernel launch failed: " << e.what() << std::endl;
return -1;
}
size_t bytes_read = static_cast<size_t>(N) * D * H * W * C * sizeof(InDataType);
size_t bytes_written = static_cast<size_t>(N) * Do * Ho * Wo * C * sizeof(OutDataType);
float bandwidth = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
std::cout << " Latency: " << latency << " ms" << std::endl;
std::cout << " Bandwidth: " << bandwidth << " GB/s" << std::endl;
if(verify)
{
d_out.FromDevice(h_out.data());
d_out_index.FromDevice(h_out_index.data());
auto kernel_args =
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
decltype(input_shape),
decltype(window_lengths),
SelectedKernel::kOutputIndex>(
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
bool pass_value =
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
std::cout << " Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
if(SelectedKernel::kOutputIndex)
{
bool pass_index = ck_tile::check_err(
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL")
<< std::endl;
}
}
return 0;
}
}
// Entry point: dispatch to the pooling benchmark whose dimensionality
// (2 or 3) was baked in at code-generation time via POOLING_DIM.
int main(int argc, char* argv[])
{
    return benchmark_pooling<POOLING_DIM>(argc, argv);
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <sstream>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/pooling.hpp"
namespace ck_tile {
/// @brief Kernel trait parameters for pooling tile_engine configurations
/// @brief Kernel trait parameters for pooling tile_engine configurations
struct PoolingKernelTraits
{
    std::string reduce_op; // "max", "min", or "avg"
    bool output_index;     // Whether to output indices (max pooling)
    bool propagate_nan;    // Whether to propagate NaN values
    bool cross_warp;       // Whether cross-warp reduction is used

    /// @brief Encode the traits as one underscore-separated token,
    ///        e.g. "max_idx_nonan_nocrosswarp".
    std::string to_string() const
    {
        std::string encoded = reduce_op;
        encoded += output_index ? "_idx" : "_noidx";
        encoded += propagate_nan ? "_nan" : "_nonan";
        encoded += cross_warp ? "_crosswarp" : "_nocrosswarp";
        return encoded;
    }
};
/// @brief Extract traits from a kernel name string
/// @brief Extract traits from a kernel name string.
/// Each boolean trait is flagged by a token whose "no"-prefixed
/// counterpart must be absent (e.g. "idx" present but "noidx" absent).
inline PoolingKernelTraits extract_pooling_traits_from_name(const std::string& name)
{
    const auto contains = [&name](const char* token) {
        return name.find(token) != std::string::npos;
    };

    PoolingKernelTraits traits;

    // Reduce op: "max" wins over "min"; anything else defaults to average.
    if(contains("max"))
    {
        traits.reduce_op = "max";
    }
    else if(contains("min"))
    {
        traits.reduce_op = "min";
    }
    else
    {
        traits.reduce_op = "avg";
    }

    traits.output_index  = contains("idx") && !contains("noidx");
    traits.propagate_nan = contains("nan") && !contains("nonan");
    traits.cross_warp    = contains("crosswarp") && !contains("nocrosswarp");
    return traits;
}
} // namespace ck_tile

View File

@@ -0,0 +1,551 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Pooling kernel instance builder for tile_engine.
Generates C++ kernel headers for pooling operations with specific tile
configurations and trait combinations.
Usage:
--list_kernels: List valid kernel configurations
--gen_single: Generate a single kernel header
--gen_individual: Generate all kernel headers
"""
import os
import json
import argparse
import itertools
import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from pooling_validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
get_dtype_string,
get_reduce_op_string,
)
logger = logging.getLogger(__name__)
class PoolingKernelBuilder:
    """Builds C++ pooling kernel instance headers for the tile_engine.

    The tile/trait search space comes either from a user-supplied JSON
    config file or from built-in defaults; every valid combination can be
    listed (for CMake consumption) or emitted as an individual ``.hpp``
    header file.
    """

    def __init__(self, working_path, datatype, config_json=None):
        # working_path: output directory for generated headers/lists.
        # datatype: element-type key such as "fp16" (see validation utils).
        # config_json: optional JSON file overriding the default config.
        self.working_path = Path(working_path)
        self.datatype = datatype
        self.config_json = config_json
        # Create working directory if it doesn't exist
        self.working_path.mkdir(parents=True, exist_ok=True)
        # Load configuration
        if config_json and os.path.exists(config_json):
            with open(config_json, "r") as f:
                self.config = json.load(f)
        else:
            self.config = self._get_default_config()

    def _get_default_config(self):
        """Return default configuration if no config file is provided"""
        return {
            "tile_config": {
                "block_m": {"values": [64,128,256]},
                "block_n": {"values": [1,2]},
                "warp_m": {"values": [1]},
                "warp_n": {"values": [1]},
                "warp_tile_m": {"values": [128]},
                "warp_tile_n": {"values": [1]},
                "thread_tile_m": {"values": [1,2,4]},
                "thread_tile_n": {"values": [1]},
            },
            "trait_config": {
                "reduce_op": {"values": ["max", "min", "avg"]},
                "output_index": {"values": [True, False]},
                "propagate_nan": {"values": [True, False]},
                "pooling_dim": {"values": ["2d", "3d"]},
            },
        }

    def _get_tile_configs(self, fast_mode=False):
        """Get tile configurations from config.

        Returns the cartesian product of every tile parameter, filtered
        through the validation utilities; fast_mode skips the expensive
        checks (used for listing only).
        """
        if "tile_config" not in self.config:
            return []
        tile_config = self.config["tile_config"]
        block_m_values = tile_config.get("block_m", {}).get("values", [64,128,256])
        block_n_values = tile_config.get("block_n", {}).get("values", [1,2])
        warp_m_values = tile_config.get("warp_m", {}).get("values", [1])
        warp_n_values = tile_config.get("warp_n", {}).get("values", [1])
        warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [128])
        warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [1])
        thread_tile_m_values = tile_config.get("thread_tile_m", {}).get("values", [1,2,4])
        thread_tile_n_values = tile_config.get("thread_tile_n", {}).get("values", [1])
        configs = []
        # Exhaustive sweep; invalid combinations are rejected below.
        for block_m in block_m_values:
            for block_n in block_n_values:
                for warp_m in warp_m_values:
                    for warp_n in warp_n_values:
                        for warp_tile_m in warp_tile_m_values:
                            for warp_tile_n in warp_tile_n_values:
                                for thread_tile_m in thread_tile_m_values:
                                    for thread_tile_n in thread_tile_n_values:
                                        if self._validate_tile_config(
                                            block_m,
                                            block_n,
                                            warp_m,
                                            warp_n,
                                            warp_tile_m,
                                            warp_tile_n,
                                            thread_tile_m,
                                            thread_tile_n,
                                            fast_mode=fast_mode,
                                        ):
                                            configs.append(
                                                {
                                                    "block_m": block_m,
                                                    "block_n": block_n,
                                                    "warp_m": warp_m,
                                                    "warp_n": warp_n,
                                                    "warp_tile_m": warp_tile_m,
                                                    "warp_tile_n": warp_tile_n,
                                                    "thread_tile_m": thread_tile_m,
                                                    "thread_tile_n": thread_tile_n,
                                                }
                                            )
        return configs

    def _validate_tile_config(
        self,
        block_m,
        block_n,
        warp_m,
        warp_n,
        warp_tile_m,
        warp_tile_n,
        thread_tile_m,
        thread_tile_n,
        fast_mode=False,
    ):
        """Validate tile configuration via pooling_validation_utils."""
        return is_tile_config_valid(
            block_m,
            block_n,
            warp_m,
            warp_n,
            warp_tile_m,
            warp_tile_n,
            thread_tile_m,
            thread_tile_n,
            self.datatype,
            self.datatype,
            fast_mode=fast_mode,
        )

    def _generate_trait_combinations(self):
        """Generate all combinations of traits.

        Returns (reduce_op, output_index, propagate_nan, pooling_dim)
        tuples, keeping only combinations CK can represent.
        """
        if "trait_config" not in self.config:
            return [("max", True, False, "2d")]
        trait_config = self.config["trait_config"]
        reduce_ops = trait_config.get("reduce_op", {}).get("values", ["min","max","avg"])
        output_indices = trait_config.get("output_index", {}).get("values", [True, False])
        propagate_nans = trait_config.get("propagate_nan", {}).get("values", [True, False])
        pooling_dims = trait_config.get("pooling_dim", {}).get("values", ["2d", "3d"])
        all_combinations = list(
            itertools.product(reduce_ops, output_indices, propagate_nans, pooling_dims)
        )
        # Filter valid combinations
        combinations = []
        for combo in all_combinations:
            reduce_op, output_index, propagate_nan, pooling_dim = combo
            if is_trait_combination_valid(
                reduce_op, output_index, propagate_nan, pooling_dim
            ):
                combinations.append(combo)
            else:
                logger.debug(
                    f"Skipping unsupported trait combination: {reduce_op}-{output_index}-{propagate_nan}-{pooling_dim}"
                )
        return combinations

    def _get_dtype_string(self):
        """Get C++ type string for datatype."""
        return get_dtype_string(self.datatype)

    def _get_reduce_op_string(self, reduce_op):
        """Get C++ reduce op type string."""
        return get_reduce_op_string(reduce_op)

    def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
        """Generate a single kernel instance header.

        Returns (kernel_name, instance_code). The emitted code defines a
        SelectedKernel wrapper plus the type aliases the benchmark driver
        expects (KERNEL_NAME, POOLING_DIM, ...).
        """
        reduce_op, output_index, propagate_nan, pooling_dim = trait_combo
        # Create kernel name
        kernel_name = (
            f"pool_{self.datatype}_{pooling_dim}_{reduce_op}_"
            f"{'idx' if output_index else 'noidx'}_"
            f"{'nan' if propagate_nan else 'nonan'}"
        )
        # Create tile configuration string
        tile_str = (
            f"{tile_config['block_m']}x{tile_config['block_n']}_"
            f"{tile_config['warp_m']}x{tile_config['warp_n']}_"
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}_"
            f"{tile_config['thread_tile_m']}x{tile_config['thread_tile_n']}"
        )
        kernel_name += f"_{tile_str}"
        # Determine types
        in_type = self._get_dtype_string()
        out_type = in_type
        compute_type = "float" # Always use float for computation
        index_type = "ck_tile::index_t"
        reduce_op_type = self._get_reduce_op_string(reduce_op)
        output_index_str = "true" if output_index else "false"
        propagate_nan_str = "true" if propagate_nan else "false"
        # Generate 2D or 3D specific code
        if pooling_dim == "2d":
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t>"
            window_rank = 2
        else:
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = (
                "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            )
            window_rank = 3
        pragma_line = "#pragma once\n" if is_header else ""
        instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/pooling.hpp"
using InDataType = {in_type};
using OutDataType = {out_type};
using ComputeDataType = {compute_type};
using IndexDataType = {index_type};
using ReduceOpType = {reduce_op_type};
using TensorShape = {tensor_shape_type};
using WindowShape = {window_shape_type};
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
constexpr int POOLING_DIM = {window_rank};
// Wrapper for simplified launch interface
struct SelectedKernel {{
// Tile configuration - PoolShape parameters
static constexpr ck_tile::index_t Block_M = {tile_config["block_m"]};
static constexpr ck_tile::index_t Block_N = {tile_config["block_n"]};
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
static constexpr ck_tile::index_t WarpTile_M = {tile_config["warp_tile_m"]};
static constexpr ck_tile::index_t WarpTile_N = {tile_config["warp_tile_n"]};
static constexpr ck_tile::index_t ThreadTile_M = {tile_config["thread_tile_m"]};
static constexpr ck_tile::index_t ThreadTile_N = {tile_config["thread_tile_n"]};
// Traits
static constexpr bool kOutputIndex = {output_index_str};
static constexpr bool kPropagateNan = {propagate_nan_str};
// Pool shape
using BlockWarps = ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N>;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using WarpTile = ck_tile::sequence<WarpTile_M, WarpTile_N>;
using ThreadTile = ck_tile::sequence<ThreadTile_M, ThreadTile_N>;
using PoolShapeType = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
// Problem and kernel types
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
ReduceOpType,
kOutputIndex,
kPropagateNan,
PoolShapeType>;
using Kernel = ck_tile::PoolKernel<Problem>;
static float launch(ck_tile::PoolHostArgs<TensorShape, WindowShape>& args,
const ck_tile::stream_config& stream) {{
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
auto kernel_args = Kernel::MakeKernelArgs(args);
if (!Kernel::IsSupportedArgument(kernel_args)) {{
throw std::runtime_error(
std::string("Unsupported arguments for pooling kernel: ") + KERNEL_NAME);
}}
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
if(stream.log_level_ > 0) {{
std::cout << "Launching pooling kernel: " << KERNEL_NAME << "\\n"
<< " grid_size: " << kGridSize << ", block_size: " << kBlockSize
<< std::endl;
}}
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, kGridSize, kBlockSize, 0, kernel_args));
}}
}};
"""
        return kernel_name, instance_code

    def write_kernel_list(self):
        """Write kernel list to file for CMake to read.

        Emits pool_kernel_count.txt (one integer) and pool_kernel_list.txt
        ("name|tile_str|trait_str" per line) into the working directory.
        """
        tile_configs = self._get_tile_configs(fast_mode=False)
        trait_combos = self._generate_trait_combinations()
        kernel_list = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                reduce_op, output_index, propagate_nan, pooling_dim = trait_combo
                kernel_name = (
                    f"pool_{self.datatype}_{pooling_dim}_{reduce_op}_"
                    f"{'idx' if output_index else 'noidx'}_"
                    f"{'nan' if propagate_nan else 'nonan'}"
                )
                tile_str = (
                    f"{tile_config['block_m']}x{tile_config['block_n']}_"
                    f"{tile_config['warp_m']}x{tile_config['warp_n']}_"
                    f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}_"
                    f"{tile_config['thread_tile_m']}x{tile_config['thread_tile_n']}"
                )
                kernel_name += f"_{tile_str}"
                trait_str = (
                    f"{reduce_op}_"
                    f"{'true' if output_index else 'false'}_"
                    f"{'true' if propagate_nan else 'false'}_"
                    f"{pooling_dim}"
                )
                kernel_list.append(
                    {
                        "name": kernel_name,
                        "tile_config": tile_config,
                        "trait_combo": trait_combo,
                        "tile_str": tile_str,
                        "trait_str": trait_str,
                    }
                )
        # Write kernel count
        with open(self.working_path / "pool_kernel_count.txt", "w") as f:
            f.write(str(len(kernel_list)))
        # Write kernel list
        with open(self.working_path / "pool_kernel_list.txt", "w") as f:
            for kernel in kernel_list:
                f.write(
                    f"{kernel['name']}|{kernel['tile_str']}|{kernel['trait_str']}\n"
                )
        print(f"Listed {len(kernel_list)} kernel configurations")

    def generate_individual(self, num_workers=None):
        """Generate individual kernel files with parallel processing.

        Fans every (tile_config, trait_combo) pair out to a process pool;
        worker failures are reported but do not abort the batch.
        """
        if num_workers is None:
            num_workers = min(multiprocessing.cpu_count(), 8)
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.working_path,
                        self.datatype,
                    )
                )
        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        kernel_list = []
        completed = 0
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 10 == 0 or completed == len(work_items):
                    print(
                        f"  Progress: {completed}/{len(work_items)} kernels generated"
                    )
                try:
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")
        kernel_list.sort(key=lambda x: x[0])
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )

    def run(self, num_workers=None):
        """Run the builder to generate individual kernel files"""
        self.generate_individual(num_workers)
def _generate_single_kernel_individual(work_item):
    """Process-pool worker: emit one kernel header.

    work_item is a (tile_config, trait_combo, working_path, datatype)
    tuple; returns (kernel_name, trait_combo, tile_config) on success or
    None on failure (failures are printed, never raised).
    """
    tile_config, trait_combo, working_path, datatype = work_item
    # Each worker process builds its own builder instance.
    builder = PoolingKernelBuilder(working_path, datatype)
    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )
        out_path = working_path / f"pooling_single_{kernel_name}.hpp"
        with open(out_path, "w") as header:
            header.write(instance_code)
    except Exception as exc:
        # Swallow-and-report: one bad config must not kill the whole pool.
        print(f"Error generating individual kernel: {exc}")
        return None
    return (kernel_name, trait_combo, tile_config)
def main():
    """CLI entry point: list, single-generate, or batch-generate kernels."""
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description="Pooling kernel instance builder for tile_engine"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp8", "fp16", "bf16", "fp32"],
        help="Data type",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_individual", action="store_true", help="Generate individual kernel files"
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()
    builder = PoolingKernelBuilder(args.working_path, args.datatype, args.config_json)
    if args.list_kernels:
        builder.write_kernel_list()
    elif args.gen_single:
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config:
        # "block_mxblock_n_warp_mxwarp_n_warp_tile_mxwarp_tile_n_thread_tile_mxthread_tile_n"
        # NOTE(review): malformed strings raise IndexError/ValueError here;
        # inputs are expected to come from write_kernel_list output.
        tile_parts = args.tile_config.split("_")
        block_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        thread_tile_dims = tile_parts[3].split("x")
        tile_config = {
            "block_m": int(block_dims[0]),
            "block_n": int(block_dims[1]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "thread_tile_m": int(thread_tile_dims[0]),
            "thread_tile_n": int(thread_tile_dims[1]),
        }
        # Parse trait combo: "reduce_op_output_index_propagate_nan_pooling_dim"
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # reduce_op
            trait_parts[1].lower() == "true",  # output_index
            trait_parts[2].lower() == "true",  # propagate_nan
            trait_parts[3],  # pooling_dim
        )
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )
        header_file = builder.working_path / f"pooling_single_{kernel_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)
        print(f"Generated {header_file}")
    elif args.gen_individual:
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_individual, or --gen_single"
        )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,487 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Validation utilities for pooling tile_engine configurations.
Validates tile configurations, trait combinations, and datatype support for
pooling kernels. Modelled after gemm_validation_utils.py — each constraint
from the CK PoolShape / PoolKernel static_asserts is mirrored here so that
invalid configs are rejected at code-generation time rather than at compile
or runtime.
"""
import logging
from typing import List, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------
# Default warp size (wave64 for CDNA architectures)
# Default warp size used when no GPU target is given (wave64 / CDNA).
WARP_SIZE = 64
MAX_BLOCK_SIZE = 1024  # Maximum threads per workgroup on AMD GPUs
# NOTE(review): MAX_LDS_BYTES is not referenced by any validator in this
# module — confirm whether an LDS-usage check was intended.
MAX_LDS_BYTES = 65536  # 64 KB LDS per workgroup
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Return the wavefront width (threads per wave) for *gpu_target*.

    CDNA parts (gfx9xx) are WAVE64; RDNA parts (gfx10xx/11xx/12xx) and any
    unrecognized target are treated as WAVE32.
    """
    return 64 if gpu_target.startswith("gfx9") else 32
# ---------------------------------------------------------------------------
# Datatype helpers
# ---------------------------------------------------------------------------
# Byte width of one element for each datatype key accepted by the builder.
ELEMENT_SIZE_MAP = {
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "fp16": 2,
    "bf16": 2,
    "int4": 0.5,  # sub-byte type: fractional byte width
    "int32": 4,
    "fp32": 4,
    "fp64": 8,
}
# Datatype key -> C++ type spelling used in generated headers.
DTYPE_STRING_MAP = {
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp32": "float",
    "fp64": "double",
}
SUPPORTED_DATATYPES = list(DTYPE_STRING_MAP.keys())
# ---------------------------------------------------------------------------
# Reduce-op helpers
# ---------------------------------------------------------------------------
# Reduce-op key -> C++ enumerator ("avg" maps to Add; the kernel is
# presumably expected to divide by the window size — confirm in PoolKernel).
REDUCE_OP_STRING_MAP = {
    "max": "ck_tile::ReduceOp::Max",
    "min": "ck_tile::ReduceOp::Min",
    "avg": "ck_tile::ReduceOp::Add",
}
SUPPORTED_REDUCE_OPS = list(REDUCE_OP_STRING_MAP.keys())
SUPPORTED_POOLING_DIMS = ("2d", "3d")
# ---------------------------------------------------------------------------
# Public helper functions (used by the instance builder)
# ---------------------------------------------------------------------------
def element_size(datatype: str) -> float:
    """Return the byte-width of a single element for *datatype*.

    Lookup is case-insensitive; raises ValueError for unknown keys.
    """
    key = datatype.lower()
    if key not in ELEMENT_SIZE_MAP:
        raise ValueError(
            f"Unsupported data type: '{key}'. "
            f"Supported: {list(ELEMENT_SIZE_MAP.keys())}"
        )
    return ELEMENT_SIZE_MAP[key]
def get_dtype_string(datatype: str) -> str:
    """Return the C++ type string (e.g. ``ck_tile::fp16_t``) for *datatype*.

    Unknown keys fall back to ``"float"``.
    """
    if datatype in DTYPE_STRING_MAP:
        return DTYPE_STRING_MAP[datatype]
    return "float"
def get_reduce_op_string(reduce_op: str) -> str:
    """Return the C++ ReduceOp enumerator string for *reduce_op*.

    Unknown keys fall back to ``"ck_tile::ReduceOp::Max"``.
    """
    if reduce_op in REDUCE_OP_STRING_MAP:
        return REDUCE_OP_STRING_MAP[reduce_op]
    return "ck_tile::ReduceOp::Max"
# ---------------------------------------------------------------------------
# Individual tile-config validators
# ---------------------------------------------------------------------------
def validate_positivity(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters must be positive integers.

    Returns (True, "") on success, else (False, reason) naming the first
    offending parameter.
    """
    checks = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    for name, value in checks:
        if value <= 0:
            return False, f"{name} ({value}) must be > 0"
    return True, ""
def validate_power_of_two(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters should be powers of two for correct GPU addressing.

    Non-positive values are ignored here (validate_positivity covers them).
    """
    checks = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    for name, value in checks:
        # value & (value - 1) clears the lowest set bit: zero iff power of 2.
        if value > 0 and value & (value - 1):
            return False, f"{name} ({value}) is not a power of two"
    return True, ""
def validate_thread_tile_alignment(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert(Warp_M % ThreadTile_M == 0);
        static_assert(Warp_N % ThreadTile_N == 0);
    """
    for axis, warp_val, thread_val in (
        ("m", warp_tile_m, thread_tile_m),
        ("n", warp_tile_n, thread_tile_n),
    ):
        if warp_val % thread_val != 0:
            return (
                False,
                f"warp_tile_{axis} ({warp_val}) must be divisible by "
                f"thread_tile_{axis} ({thread_val})",
            )
    return True, ""
def validate_warp_thread_distribution(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N)
                      % get_warp_size() == 0);
    """
    # Number of lanes needed to cover the warp tile at thread-tile granularity.
    lanes = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    if lanes % warp_size == 0:
        return True, ""
    return (
        False,
        f"(warp_tile_m * warp_tile_n) / (thread_tile_m * thread_tile_n) = "
        f"{lanes} is not a multiple of warp_size ({warp_size})",
    )
def _compute_warp_size_scale_factors(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[int, int]:
    """
    Reproduce the WarpSizeScaleFactor_M / _N logic from pool_shape.hpp.

    The surplus warp occupancy is attributed to whichever dimension has
    more per-thread iterations; the other dimension gets a factor of 1.
    """
    lanes = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    scale = lanes // warp_size
    m_iters = warp_tile_m // thread_tile_m
    n_iters = warp_tile_n // thread_tile_n
    return (scale, 1) if m_iters > n_iters else (1, scale)
def validate_block_tile_coverage(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Block_M * WarpSizeScaleFactor_M) %
                      (WarpPerBlock_M * Warp_M) == 0);
        static_assert((Block_N * WarpSizeScaleFactor_N) %
                      (WarpPerBlock_N * Warp_N) == 0);
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    for axis, block, sf, warps, warp_tile in (
        ("m", block_m, sf_m, warp_m, warp_tile_m),
        ("n", block_n, sf_n, warp_n, warp_tile_n),
    ):
        if (block * sf) % (warps * warp_tile) != 0:
            return (
                False,
                f"block_{axis}*ScaleFactor_{axis.upper()} ({block}*{sf}={block * sf}) must be "
                f"divisible by warp_{axis}*warp_tile_{axis} ({warps}*{warp_tile}"
                f"={warps * warp_tile})",
            )
    return True, ""
def validate_block_size(
    warp_m: int,
    warp_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """BlockSize = warp_size * warp_m * warp_n must be <= MAX_BLOCK_SIZE."""
    threads = warp_size * warp_m * warp_n
    if threads <= MAX_BLOCK_SIZE:
        return True, ""
    return (
        False,
        f"BlockSize ({threads} = {warp_size}*{warp_m}*{warp_n}) "
        f"exceeds maximum ({MAX_BLOCK_SIZE})",
    )
def validate_vector_load_alignment(
    block_m: int,
    thread_tile_m: int,
    in_datatype: str,
) -> Tuple[bool, str]:
    """
    The M-dimension thread-tile determines the contiguous vector load width.
    It must produce a load whose byte-width divides 16 bytes (max global
    vector load width on AMD GPUs) and is at least 1 element wide.

    ``block_m`` is part of the public signature but is not used by this
    check — the load width depends only on the per-thread tile.
    """
    elem_bytes = element_size(in_datatype)
    # May be fractional for sub-byte types (e.g. int4 -> 0.5B per element).
    load_bytes = thread_tile_m * elem_bytes
    if load_bytes > 16:
        return (
            False,
            f"thread_tile_m ({thread_tile_m}) * element_size({in_datatype}, "
            f"{elem_bytes}B) = {load_bytes}B exceeds 16B max vector load",
        )
    # Simplified from "16 % load_bytes != 0 and load_bytes % 16 != 0":
    # given load_bytes <= 16 the second clause only mattered at exactly
    # 16B, where 16 % load_bytes == 0 anyway — so the single divisibility
    # test below is equivalent and clearer.
    if 16 % load_bytes != 0:
        return (
            False,
            f"Vector load width ({load_bytes}B) is not a divisor of 16B",
        )
    return True, ""
def validate_repeat_factors(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Repeat_M and Repeat_N from pool_shape.hpp must be >= 1. They are the
    number of tile iterations each warp performs within the block.
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
    )
    repeats = {
        "Repeat_M": (block_m * sf_m) // (warp_m * warp_tile_m),
        "Repeat_N": (block_n * sf_n) // (warp_n * warp_tile_n),
    }
    for name, value in repeats.items():
        if value < 1:
            return False, f"{name} ({value}) must be >= 1"
    return True, ""
# ---------------------------------------------------------------------------
# Comprehensive tile-config validation (entry point)
# ---------------------------------------------------------------------------
def is_tile_config_valid(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    in_datatype: str,
    out_datatype: str,
    fast_mode: bool = False,
    gpu_target: str = "gfx90a",
) -> bool:
    """
    Comprehensive pooling tile configuration validation.

    When *fast_mode* is True only cheap sanity checks are performed (useful
    for the ``--list_kernels`` path). Full mode mirrors every
    ``static_assert`` in ``pool_shape.hpp``. Checks run in order and stop
    at the first failure, logging the reason at DEBUG level.

    Parameters
    ----------
    block_m, block_n : Block tile dimensions (M = output elems, N = window).
    warp_m, warp_n : Warps per block along each dimension.
    warp_tile_m, warp_tile_n : Tile processed per warp.
    thread_tile_m, thread_tile_n : Contiguous elements per thread.
    in_datatype : Input element type (e.g. ``"fp16"``).
    out_datatype : Output element type (currently unused by the checks).
    fast_mode : Skip expensive checks when True.
    gpu_target : Architecture string used to pick the warp size.
    """
    all_params = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )
    # --- Positivity (always) ---
    ok, err = validate_positivity(*all_params)
    if not ok:
        logger.debug(f"Positivity check failed: {err}")
        return False
    # --- Thread-tile alignment (always) ---
    # Must run before any check that divides by the thread tile.
    ok, err = validate_thread_tile_alignment(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
    )
    if not ok:
        logger.debug(f"Thread tile alignment failed: {err}")
        return False
    if fast_mode:
        return True
    # Get the warp size for this GPU target
    warp_size = get_warp_size_for_gpu(gpu_target)
    # --- Power-of-two ---
    ok, err = validate_power_of_two(*all_params)
    if not ok:
        logger.debug(f"Power-of-two check failed: {err}")
        return False
    # --- Warp-thread distribution ---
    ok, err = validate_warp_thread_distribution(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    if not ok:
        logger.debug(f"Warp thread distribution failed: {err}")
        return False
    # --- Block-tile coverage ---
    ok, err = validate_block_tile_coverage(*all_params, warp_size=warp_size)
    if not ok:
        logger.debug(f"Block tile coverage failed: {err}")
        return False
    # --- Block size ---
    ok, err = validate_block_size(warp_m, warp_n, warp_size)
    if not ok:
        logger.debug(f"Block size check failed: {err}")
        return False
    # --- Repeat factors ---
    # NOTE(review): validate_repeat_factors uses the module default
    # WARP_SIZE, not the gpu_target-derived warp_size — confirm intended.
    ok, err = validate_repeat_factors(*all_params)
    if not ok:
        logger.debug(f"Repeat factor check failed: {err}")
        return False
    # --- Vector load alignment ---
    ok, err = validate_vector_load_alignment(block_m, thread_tile_m, in_datatype)
    if not ok:
        logger.debug(f"Vector load alignment failed: {err}")
        return False
    return True
# ---------------------------------------------------------------------------
# Trait-combination validation
# ---------------------------------------------------------------------------
def is_trait_combination_valid(
    reduce_op: str,
    output_index: bool,
    propagate_nan: bool,
    pooling_dim: str,
) -> bool:
    """
    Validate a pooling trait combination.

    Parameters
    ----------
    reduce_op    : ``"max"``, ``"min"``, or ``"avg"``.
    output_index : Whether to output indices of the selected elements.
    propagate_nan: Whether to propagate NaN values through the reduction.
    pooling_dim  : ``"2d"`` or ``"3d"``.
    """
    # Checks run in order; only the first failing reason is logged.
    failures = (
        (reduce_op not in SUPPORTED_REDUCE_OPS,
         f"Unsupported reduce_op: '{reduce_op}'"),
        (pooling_dim not in SUPPORTED_POOLING_DIMS,
         f"Invalid pooling dimension: '{pooling_dim}'"),
        # output_index only makes sense for max pooling (CK constraint)
        (output_index and reduce_op != "max",
         f"output_index=True is only supported for 'max' pooling, "
         f"not '{reduce_op}'"),
    )
    for failed, reason in failures:
        if failed:
            logger.debug(reason)
            return False
    return True
# ---------------------------------------------------------------------------
# Datatype validation
# ---------------------------------------------------------------------------
def is_datatype_supported(datatype: str) -> bool:
    """Return True if *datatype* is a known pooling datatype (case-insensitive)."""
    normalized = datatype.lower()
    return normalized in ELEMENT_SIZE_MAP

View File

@@ -11,7 +11,7 @@ set(MULTI_REDUCE_VARIANTS "multiops_multiblock;multiops_threadwise" CACHE STRING
function(build_multi_reduce_for_datatype datatype variant)
# Filter GPU targets to only gfx942, and gfx950
set(GPU_TARGETS "")
set(DESIRED_TARGETS "gfx942;gfx950")
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
set(VALID_VARIANTS "multiops_multiblock;multiops_threadwise")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
@@ -22,7 +22,7 @@ function(build_multi_reduce_for_datatype datatype variant)
# Skip compilation if no matching targets found
if(NOT GPU_TARGETS)
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()