mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4469 (commit 0844cb0)
[CK_TILE] Add pooling in tile_engine ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> Add pooling in ck tile engine ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
791afc6465
commit
119712bd90
@@ -7,5 +7,6 @@ include_directories(BEFORE
|
||||
|
||||
add_subdirectory(ops/gemm)
|
||||
add_subdirectory(ops/gemm_streamk)
|
||||
add_subdirectory(ops/pooling)
|
||||
add_subdirectory(ops/reduce)
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950
|
||||
set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -242,7 +242,7 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
|
||||
@@ -219,7 +219,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, and gfx950
|
||||
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx12-generic")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -230,7 +230,7 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
|
||||
@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
|
||||
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -237,7 +237,7 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
|
||||
@@ -25,6 +25,16 @@ ELEMENT_SIZE_MAP = {
|
||||
"fp64": 8,
|
||||
}
|
||||
|
||||
def get_warp_size_for_gpu(gpu_target: str) -> int:
|
||||
"""Get the warp size for a given GPU target.
|
||||
|
||||
CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront).
|
||||
RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront).
|
||||
"""
|
||||
if gpu_target.startswith("gfx9"):
|
||||
return 64 # CDNA - WAVE64
|
||||
return 32 # RDNA and others - WAVE32
|
||||
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
@@ -586,10 +596,11 @@ def validate_whole_wg_cover_configuration(
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
gpu_target: str = "gfx90a",
|
||||
) -> Tuple[bool, str]:
|
||||
# Validate whole workgroup cover configuration
|
||||
|
||||
warp_size = 64
|
||||
warp_size = get_warp_size_for_gpu(gpu_target)
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
@@ -704,6 +715,73 @@ def wg_cover_core_validation(
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_cshuffle_epilogue_distribution(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_size: int,
|
||||
c_datatype: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate that the CShuffleEpilogue tile distribution pattern is valid.
|
||||
|
||||
This mirrors the static_assert in static_encoding_pattern.hpp:
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
The CShuffleEpilogue creates a tile_distribution_encoding_pattern_2d<BlockSize, YPerTile, XPerTile, VecSize, thread_raked>
|
||||
where:
|
||||
- BlockSize = warp_m * warp_n * warp_k * warp_size
|
||||
- YPerTile = MPerIterationShuffle (derived from tile_m / (warp_m * warp_tile_m / some_factor))
|
||||
- XPerTile = NPerIterationShuffle (derived from tile_n)
|
||||
- VecSize = vector size based on element size (typically 8 for fp16)
|
||||
|
||||
The key constraint is that X0 must evenly divide warp_size, where:
|
||||
- X0 = min(warp_size, XPerTile / X1)
|
||||
- X1 = min(VecSize, LargestVec)
|
||||
- LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
|
||||
"""
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2)
|
||||
VecSize = 16 // elem_size
|
||||
|
||||
XPerTile = tile_n
|
||||
YPerTile = tile_m // warp_m
|
||||
|
||||
if XPerTile <= 0 or YPerTile <= 0:
|
||||
return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}"
|
||||
|
||||
num_warps = BlockSize // warp_size
|
||||
if num_warps * warp_size == 0:
|
||||
return False, "Invalid BlockSize or warp_size"
|
||||
|
||||
LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size)
|
||||
if LargestVec <= 0:
|
||||
LargestVec = 1
|
||||
|
||||
X1 = min(VecSize, LargestVec) if LargestVec > 0 else VecSize
|
||||
if X1 <= 0:
|
||||
X1 = 1
|
||||
|
||||
X0 = min(warp_size, XPerTile // X1) if X1 > 0 else warp_size
|
||||
|
||||
Y1 = warp_size // X0 if X0 > 0 else 0
|
||||
|
||||
if X0 * Y1 != warp_size:
|
||||
return (
|
||||
False,
|
||||
f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). "
|
||||
f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}"
|
||||
)
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_global_vector_load_size(
|
||||
BlockSize: int,
|
||||
KPerBlock: int,
|
||||
@@ -766,6 +844,8 @@ def validate_gemm(
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
# GEMM Validation
|
||||
warp_size = get_warp_size_for_gpu(gpu_target)
|
||||
|
||||
# Validate whole workgroup cover configuration
|
||||
whole_workgroup_cover_valid, whole_workgroup_cover_error = (
|
||||
validate_whole_wg_cover_configuration(
|
||||
@@ -778,6 +858,7 @@ def validate_gemm(
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
)
|
||||
if not whole_workgroup_cover_valid:
|
||||
@@ -786,6 +867,23 @@ def validate_gemm(
|
||||
)
|
||||
return False, whole_workgroup_cover_error
|
||||
|
||||
# Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue)
|
||||
# This validation ensures the tile distribution pattern is valid for the output tile
|
||||
cshuffle_valid, cshuffle_error = validate_cshuffle_epilogue_distribution(
|
||||
tile_m,
|
||||
tile_n,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_size,
|
||||
c_datatype,
|
||||
)
|
||||
if not cshuffle_valid:
|
||||
logging.debug(f"CShuffleEpilogue validation failed: {cshuffle_error}")
|
||||
return False, cshuffle_error
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@@ -808,6 +906,8 @@ def validate_gemm_preshuffle(
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
# Preshuffle Validations
|
||||
warp_size = get_warp_size_for_gpu(gpu_target)
|
||||
|
||||
# Validate vector load alignment
|
||||
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
|
||||
vector_valid, vector_error = validate_vector_load_alignment(
|
||||
@@ -815,7 +915,7 @@ def validate_gemm_preshuffle(
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
m_iter_per_warp,
|
||||
wave_size=64,
|
||||
wave_size=warp_size,
|
||||
vector_load_size=16,
|
||||
)
|
||||
if not vector_valid:
|
||||
@@ -831,7 +931,7 @@ def validate_gemm_preshuffle(
|
||||
warp_k,
|
||||
a_datatype,
|
||||
vector_load_size=16,
|
||||
warp_size=64,
|
||||
warp_size=warp_size,
|
||||
)
|
||||
if not m0_m1_m2_valid:
|
||||
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
|
||||
|
||||
@@ -226,7 +226,7 @@ message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx942, gfx950
|
||||
set(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx942;gfx950")
|
||||
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -237,7 +237,7 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual Grouped GEMM targets for GPU targets: ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
|
||||
@@ -216,7 +216,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942
|
||||
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx12-generic") # TODO: Add gfx950 when supported
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -227,7 +227,7 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
|
||||
212
tile_engine/ops/pooling/CMakeLists.txt
Normal file
212
tile_engine/ops/pooling/CMakeLists.txt
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ============================================================================
|
||||
# Pooling Tile Engine Build Configuration
|
||||
#
|
||||
# Generates individual benchmark executables for pooling kernels
|
||||
# ============================================================================
|
||||
|
||||
set(POOLING_DATATYPE "fp8;fp16;fp32" CACHE STRING "List of datatypes for Pooling (semicolon-separated)")
|
||||
set(POOLING_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_POOLING "Enable ccache for pooling ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(POOLING_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# ============================================================================
|
||||
# create_individual_pool_target
|
||||
#
|
||||
# Creates a single benchmark executable for a specific pooling kernel config.
|
||||
# ============================================================================
|
||||
function(create_individual_pool_target datatype kernel_name trait tile_config config_json)
|
||||
if(NOT POOLING_GPU_TARGETS)
|
||||
message(WARNING "Skipping individual pooling target: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(target_name "benchmark_pooling_${datatype}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")
|
||||
# HIP clang offload uses temporary files derived from the input source basename.
|
||||
# When many targets compile the same source filename in parallel, temporary
|
||||
# files can collide and corrupt each other. Use a unique copied source per target.
|
||||
set(target_source "${CMAKE_CURRENT_BINARY_DIR}/${target_name}_pooling_benchmark_single.cpp")
|
||||
|
||||
# Generated header path - use kernel_name from pool_kernel_list.txt to match
|
||||
# the filename generated by pooling_instance_builder.py
|
||||
set(instance_header "${working_path}/pooling_single_${kernel_name}.hpp")
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${POOLING_SOURCE_DIR}/pooling_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "${kernel_name}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
DEPENDS ${POOLING_SOURCE_DIR}/pooling_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
configure_file(${POOLING_SOURCE_DIR}/pooling_benchmark_single.cpp ${target_source} COPYONLY)
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
${target_source}
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${POOLING_GPU_TARGETS})
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
POOLING_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${POOLING_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
# Add FP8 format definitions if needed
|
||||
if(CK_USE_OCP_FP8)
|
||||
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_pooling_all ${target_name})
|
||||
add_dependencies(benchmark_pooling_${datatype} ${target_name})
|
||||
|
||||
message(STATUS " Created pooling benchmark target: ${target_name}")
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# build_individual_pool_targets
|
||||
#
|
||||
# Builds all benchmark targets for a specific datatype.
|
||||
# ============================================================================
|
||||
function(build_individual_pool_targets datatype)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}")
|
||||
|
||||
# Choose config file
|
||||
if(DEFINED ENV{POOLING_CONFIG_FILE} AND NOT "$ENV{POOLING_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{POOLING_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${POOLING_CONFIG_FILE}" STREQUAL "")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${POOLING_CONFIG_FILE}")
|
||||
message(STATUS " Using custom config: ${POOLING_CONFIG_FILE}")
|
||||
else()
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(STATUS " Using default config for pooling")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# Step 1: List kernels
|
||||
message(STATUS " Listing pooling kernel configurations for ${datatype}...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/pooling_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list pooling kernels for ${datatype}: ${list_error}")
|
||||
endif()
|
||||
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/pool_kernel_count.txt)
|
||||
file(READ ${working_path}/pool_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} pooling kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Pooling kernel count file not found")
|
||||
endif()
|
||||
|
||||
# Step 2: Create targets
|
||||
if(EXISTS ${working_path}/pool_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/pool_kernel_list.txt kernel_lines)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
create_individual_pool_target("${datatype}" "${kernel_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endif()
|
||||
endforeach()
|
||||
else()
|
||||
message(FATAL_ERROR "Pooling kernel list file not found")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# MAIN EXECUTION
|
||||
# ============================================================================
|
||||
|
||||
message(STATUS "=== Starting Tile Engine Pooling Configuration ===")
|
||||
message(STATUS "POOLING_DATATYPE: ${POOLING_DATATYPE}")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets
|
||||
set(POOLING_GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND POOLING_GPU_TARGETS ${target})
|
||||
message(STATUS " Adding GPU target for pooling: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(NOT POOLING_GPU_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine Pooling build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(STATUS "Building pooling targets for GPU targets: ${POOLING_GPU_TARGETS}")
|
||||
|
||||
# Enable ccache if requested
|
||||
if(ENABLE_CCACHE_POOLING)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(STATUS "Using ccache for pooling compilation")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Create collection targets
|
||||
add_custom_target(benchmark_pooling_all)
|
||||
|
||||
foreach(dt IN LISTS POOLING_DATATYPE)
|
||||
add_custom_target(benchmark_pooling_${dt})
|
||||
endforeach()
|
||||
|
||||
# Build targets for each datatype
|
||||
foreach(dt IN LISTS POOLING_DATATYPE)
|
||||
build_individual_pool_targets(${dt})
|
||||
endforeach()
|
||||
endif()
|
||||
21
tile_engine/ops/pooling/configs/default_config.json
Normal file
21
tile_engine/ops/pooling/configs/default_config.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Default pooling configuration for tile_engine benchmarks"
|
||||
},
|
||||
"tile_config": {
|
||||
"block_m": {"values": [64,128,256]},
|
||||
"block_n": {"values": [1,2]},
|
||||
"warp_m": {"values": [1]},
|
||||
"warp_n": {"values": [1]},
|
||||
"warp_tile_m": {"values": [128]},
|
||||
"warp_tile_n": {"values": [1]},
|
||||
"thread_tile_m": {"values": [1,2,4]},
|
||||
"thread_tile_n": {"values": [1]}
|
||||
},
|
||||
"trait_config": {
|
||||
"reduce_op": {"values": ["max", "min", "avg"]},
|
||||
"output_index": {"values": [true, false]},
|
||||
"propagate_nan": {"values": [true, false]},
|
||||
"pooling_dim": {"values": ["2d", "3d"]}
|
||||
}
|
||||
}
|
||||
132
tile_engine/ops/pooling/pooling_benchmark.hpp
Normal file
132
tile_engine/ops/pooling/pooling_benchmark.hpp
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <cmath>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pooling.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Performance metrics for benchmarking
|
||||
enum class PoolMetric
|
||||
{
|
||||
LATENCY,
|
||||
BANDWIDTH
|
||||
};
|
||||
|
||||
/// @brief Pooling problem specification for 2D pooling
|
||||
struct PoolProblem2D
|
||||
{
|
||||
index_t N, H, W, C; // Input dimensions (NHWC)
|
||||
index_t Y, X; // Window dimensions
|
||||
index_t stride_h, stride_w; // Window strides
|
||||
index_t dilation_h, dilation_w; // Window dilations
|
||||
index_t pad_h_left, pad_h_right; // Height padding
|
||||
index_t pad_w_left, pad_w_right; // Width padding
|
||||
std::string datatype; // Data type name
|
||||
std::string reduce_op; // "max", "min", or "avg"
|
||||
|
||||
index_t Ho() const
|
||||
{
|
||||
index_t Ys = (Y - 1) * dilation_h + 1;
|
||||
return (H + pad_h_left + pad_h_right - Ys) / stride_h + 1;
|
||||
}
|
||||
|
||||
index_t Wo() const
|
||||
{
|
||||
index_t Xs = (X - 1) * dilation_w + 1;
|
||||
return (W + pad_w_left + pad_w_right - Xs) / stride_w + 1;
|
||||
}
|
||||
|
||||
index_t input_elements() const { return N * H * W * C; }
|
||||
index_t output_elements() const { return N * Ho() * Wo() * C; }
|
||||
|
||||
std::string to_string() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "N" << N << "_H" << H << "_W" << W << "_C" << C << "_Y" << Y << "_X" << X << "_Sh"
|
||||
<< stride_h << "_Sw" << stride_w << "_Dh" << dilation_h << "_Dw" << dilation_w;
|
||||
if(pad_h_left > 0 || pad_w_left > 0)
|
||||
oss << "_Ph" << pad_h_left << "_Pw" << pad_w_left;
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Pooling problem specification for 3D pooling
|
||||
struct PoolProblem3D
|
||||
{
|
||||
index_t N, D, H, W, C; // Input dimensions (NDHWC)
|
||||
index_t Z, Y, X; // Window dimensions
|
||||
index_t stride_d, stride_h, stride_w; // Window strides
|
||||
index_t dilation_d, dilation_h, dilation_w; // Window dilations
|
||||
index_t pad_d_left, pad_d_right; // Depth padding
|
||||
index_t pad_h_left, pad_h_right; // Height padding
|
||||
index_t pad_w_left, pad_w_right; // Width padding
|
||||
std::string datatype; // Data type name
|
||||
std::string reduce_op; // "max", "min", or "avg"
|
||||
|
||||
index_t Do() const
|
||||
{
|
||||
index_t Zs = (Z - 1) * dilation_d + 1;
|
||||
return (D + pad_d_left + pad_d_right - Zs) / stride_d + 1;
|
||||
}
|
||||
|
||||
index_t Ho() const
|
||||
{
|
||||
index_t Ys = (Y - 1) * dilation_h + 1;
|
||||
return (H + pad_h_left + pad_h_right - Ys) / stride_h + 1;
|
||||
}
|
||||
|
||||
index_t Wo() const
|
||||
{
|
||||
index_t Xs = (X - 1) * dilation_w + 1;
|
||||
return (W + pad_w_left + pad_w_right - Xs) / stride_w + 1;
|
||||
}
|
||||
|
||||
index_t input_elements() const { return N * D * H * W * C; }
|
||||
index_t output_elements() const { return N * Do() * Ho() * Wo() * C; }
|
||||
|
||||
std::string to_string() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "N" << N << "_D" << D << "_H" << H << "_W" << W << "_C" << C << "_Z" << Z << "_Y"
|
||||
<< Y << "_X" << X;
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Performance result for a pooling kernel
|
||||
struct PoolPerformanceResult
|
||||
{
|
||||
float latency_ms;
|
||||
float bandwidth_gb_s;
|
||||
|
||||
std::string to_string() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "latency=" << latency_ms << "ms, bandwidth=" << bandwidth_gb_s << "GB/s";
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Benchmark settings
|
||||
struct PoolBenchmarkSetting
|
||||
{
|
||||
int warmup = 5;
|
||||
int repeat = 20;
|
||||
bool verify = true;
|
||||
int init_method = 0; // 0: uniform random, 1: integer sequence, 2: constant, 3: special
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
390
tile_engine/ops/pooling/pooling_benchmark_single.cpp
Normal file
390
tile_engine/ops/pooling/pooling_benchmark_single.cpp
Normal file
@@ -0,0 +1,390 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file pooling_benchmark_single.cpp
|
||||
* @brief Single-kernel benchmark for pooling operations (2D and 3D).
|
||||
*
|
||||
* This benchmark includes the generated kernel header via -include flag
|
||||
* and runs the pooling kernel with specified problem sizes.
|
||||
*
|
||||
* The generated header provides:
|
||||
* - SelectedKernel (struct with ::launch())
|
||||
* - KERNEL_NAME (constexpr const char*)
|
||||
* - POOLING_DIM (constexpr int, 2 or 3)
|
||||
* - InDataType, OutDataType, ComputeDataType, IndexDataType, ReduceOpType
|
||||
* - TensorShape, WindowShape
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pooling.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include "pooling_common.hpp"
|
||||
#include "pooling_benchmark.hpp"
|
||||
|
||||
// The kernel header is included via compile command line with -include flag.
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Benchmark implementation — templated on pooling dimension so that only
|
||||
// the matching branch is instantiated (2D or 3D).
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
template <typename HostArgs>
|
||||
static float launch_selected_kernel(HostArgs& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
return SelectedKernel::launch(args, stream);
|
||||
}
|
||||
|
||||
template <int PoolDim>
|
||||
static int benchmark_pooling(int argc, char* argv[])
|
||||
{
|
||||
if constexpr(PoolDim == 2)
|
||||
{
|
||||
// ---- 2D argument parser ----
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "1", "Batch size (N)")
|
||||
.insert("h", "16", "Input height (H)")
|
||||
.insert("w", "16", "Input width (W)")
|
||||
.insert("c", "32", "Channels (C)")
|
||||
.insert("wy", "2", "Window height (Y)")
|
||||
.insert("wx", "2", "Window width (X)")
|
||||
.insert("sy", "2", "Window stride height")
|
||||
.insert("sx", "2", "Window stride width")
|
||||
.insert("dy", "1", "Window dilation height")
|
||||
.insert("dx", "1", "Window dilation width")
|
||||
.insert("phy", "0", "Padding height left")
|
||||
.insert("phyr", "0", "Padding height right")
|
||||
.insert("pwx", "0", "Padding width left")
|
||||
.insert("pwxr", "0", "Padding width right")
|
||||
.insert("verify", "1", "Verify results (0/1)")
|
||||
.insert("warmup", "5", "Warmup iterations")
|
||||
.insert("repeat", "20", "Repeat iterations")
|
||||
.insert("log", "1", "Log level");
|
||||
|
||||
if(!arg_parser.parse(argc, argv))
|
||||
return -1;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
ck_tile::index_t Y = arg_parser.get_int("wy");
|
||||
ck_tile::index_t X = arg_parser.get_int("wx");
|
||||
ck_tile::index_t Sy = arg_parser.get_int("sy");
|
||||
ck_tile::index_t Sx = arg_parser.get_int("sx");
|
||||
ck_tile::index_t Dy = arg_parser.get_int("dy");
|
||||
ck_tile::index_t Dx = arg_parser.get_int("dx");
|
||||
ck_tile::index_t LeftPy = arg_parser.get_int("phy");
|
||||
ck_tile::index_t RightPy = arg_parser.get_int("phyr");
|
||||
ck_tile::index_t LeftPx = arg_parser.get_int("pwx");
|
||||
ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
|
||||
|
||||
bool verify = arg_parser.get_int("verify") != 0;
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int log_level = arg_parser.get_int("log");
|
||||
|
||||
ck_tile::index_t Ys = (Y - 1) * Dy + 1;
|
||||
ck_tile::index_t Xs = (X - 1) * Dx + 1;
|
||||
ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
|
||||
ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
|
||||
|
||||
std::cout << "Pooling 2D benchmark: " << KERNEL_NAME << std::endl;
|
||||
std::cout << " Input: NHWC = " << N << "x" << H << "x" << W << "x" << C << std::endl;
|
||||
std::cout << " Output: NHWC = " << N << "x" << Ho << "x" << Wo << "x" << C << std::endl;
|
||||
std::cout << " Window: " << Y << "x" << X << ", stride: " << Sy << "x" << Sx
|
||||
<< ", dilation: " << Dy << "x" << Dx << std::endl;
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({N, H, W, C});
|
||||
ck_tile::HostTensor<OutDataType> h_out({N, Ho, Wo, C});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref({N, Ho, Wo, C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index({N, Ho, Wo, C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Ho, Wo, C});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
|
||||
ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in.ToDevice(h_in.data());
|
||||
d_out.SetZero();
|
||||
d_out_index.SetZero();
|
||||
|
||||
auto input_shape = ck_tile::make_tuple(N, H, W, C);
|
||||
auto output_shape = ck_tile::make_tuple(N, Ho, Wo, C);
|
||||
auto input_strides = ck_tile::make_tuple(H * W * C, W * C, C, ck_tile::index_t{1});
|
||||
auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
|
||||
auto window_lengths = ck_tile::make_tuple(Y, X);
|
||||
auto window_strides = ck_tile::make_tuple(Sy, Sx);
|
||||
auto window_dilations = ck_tile::make_tuple(Dy, Dx);
|
||||
auto input_left_pads = ck_tile::make_tuple(LeftPy, LeftPx);
|
||||
auto input_right_pads = ck_tile::make_tuple(RightPy, RightPx);
|
||||
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
d_in.GetDeviceBuffer(),
|
||||
d_out.GetDeviceBuffer(),
|
||||
d_out_index.GetDeviceBuffer(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
|
||||
|
||||
float latency = 0;
|
||||
try
|
||||
{
|
||||
latency = launch_selected_kernel(host_args, stream);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Kernel launch failed: " << e.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t bytes_read = static_cast<size_t>(N) * H * W * C * sizeof(InDataType);
|
||||
size_t bytes_written = static_cast<size_t>(N) * Ho * Wo * C * sizeof(OutDataType);
|
||||
float bandwidth = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
|
||||
|
||||
std::cout << " Latency: " << latency << " ms" << std::endl;
|
||||
std::cout << " Bandwidth: " << bandwidth << " GB/s" << std::endl;
|
||||
|
||||
if(verify)
|
||||
{
|
||||
d_out.FromDevice(h_out.data());
|
||||
d_out_index.FromDevice(h_out_index.data());
|
||||
|
||||
auto kernel_args =
|
||||
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
h_in.data(),
|
||||
h_out_ref.data(),
|
||||
h_out_ref_index.data(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
ck_tile::reference_pool2d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_lengths),
|
||||
SelectedKernel::kOutputIndex>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
|
||||
std::cout << " Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
|
||||
|
||||
if(SelectedKernel::kOutputIndex)
|
||||
{
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
|
||||
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
else // PoolDim == 3
|
||||
{
|
||||
// ---- 3D argument parser ----
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "1", "Batch size (N)")
|
||||
.insert("d", "4", "Input depth (D)")
|
||||
.insert("h", "16", "Input height (H)")
|
||||
.insert("w", "16", "Input width (W)")
|
||||
.insert("c", "32", "Channels (C)")
|
||||
.insert("wz", "2", "Window depth (Z)")
|
||||
.insert("wy", "2", "Window height (Y)")
|
||||
.insert("wx", "2", "Window width (X)")
|
||||
.insert("sz", "2", "Window stride depth")
|
||||
.insert("sy", "2", "Window stride height")
|
||||
.insert("sx", "2", "Window stride width")
|
||||
.insert("dz", "1", "Window dilation depth")
|
||||
.insert("dy", "1", "Window dilation height")
|
||||
.insert("dx", "1", "Window dilation width")
|
||||
.insert("pdz", "0", "Padding depth left")
|
||||
.insert("pdzr", "0", "Padding depth right")
|
||||
.insert("phy", "0", "Padding height left")
|
||||
.insert("phyr", "0", "Padding height right")
|
||||
.insert("pwx", "0", "Padding width left")
|
||||
.insert("pwxr", "0", "Padding width right")
|
||||
.insert("verify", "1", "Verify results (0/1)")
|
||||
.insert("warmup", "5", "Warmup iterations")
|
||||
.insert("repeat", "20", "Repeat iterations")
|
||||
.insert("log", "1", "Log level");
|
||||
|
||||
if(!arg_parser.parse(argc, argv))
|
||||
return -1;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t D = arg_parser.get_int("d");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
ck_tile::index_t Z = arg_parser.get_int("wz");
|
||||
ck_tile::index_t Y = arg_parser.get_int("wy");
|
||||
ck_tile::index_t X = arg_parser.get_int("wx");
|
||||
ck_tile::index_t Sz = arg_parser.get_int("sz");
|
||||
ck_tile::index_t Sy = arg_parser.get_int("sy");
|
||||
ck_tile::index_t Sx = arg_parser.get_int("sx");
|
||||
ck_tile::index_t Dz = arg_parser.get_int("dz");
|
||||
ck_tile::index_t Dy = arg_parser.get_int("dy");
|
||||
ck_tile::index_t Dx = arg_parser.get_int("dx");
|
||||
ck_tile::index_t LeftPz = arg_parser.get_int("pdz");
|
||||
ck_tile::index_t RightPz = arg_parser.get_int("pdzr");
|
||||
ck_tile::index_t LeftPy = arg_parser.get_int("phy");
|
||||
ck_tile::index_t RightPy = arg_parser.get_int("phyr");
|
||||
ck_tile::index_t LeftPx = arg_parser.get_int("pwx");
|
||||
ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
|
||||
|
||||
bool verify = arg_parser.get_int("verify") != 0;
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int log_level = arg_parser.get_int("log");
|
||||
|
||||
ck_tile::index_t Zs = (Z - 1) * Dz + 1;
|
||||
ck_tile::index_t Ys = (Y - 1) * Dy + 1;
|
||||
ck_tile::index_t Xs = (X - 1) * Dx + 1;
|
||||
ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
|
||||
ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
|
||||
ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
|
||||
|
||||
std::cout << "Pooling 3D benchmark: " << KERNEL_NAME << std::endl;
|
||||
std::cout << " Input: NDHWC = " << N << "x" << D << "x" << H << "x" << W << "x" << C
|
||||
<< std::endl;
|
||||
std::cout << " Output: NDHWC = " << N << "x" << Do << "x" << Ho << "x" << Wo << "x" << C
|
||||
<< std::endl;
|
||||
std::cout << " Window: " << Z << "x" << Y << "x" << X << ", stride: " << Sz << "x" << Sy
|
||||
<< "x" << Sx << ", dilation: " << Dz << "x" << Dy << "x" << Dx << std::endl;
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({N, D, H, W, C});
|
||||
ck_tile::HostTensor<OutDataType> h_out({N, Do, Ho, Wo, C});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref({N, Do, Ho, Wo, C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index({N, Do, Ho, Wo, C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index({N, Do, Ho, Wo, C});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
|
||||
ck_tile::DeviceMem d_in(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in.ToDevice(h_in.data());
|
||||
d_out.SetZero();
|
||||
d_out_index.SetZero();
|
||||
|
||||
auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
|
||||
auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
|
||||
auto input_strides =
|
||||
ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, ck_tile::index_t{1});
|
||||
auto output_strides =
|
||||
ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, ck_tile::index_t{1});
|
||||
auto window_lengths = ck_tile::make_tuple(Z, Y, X);
|
||||
auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx);
|
||||
auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
|
||||
auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
|
||||
auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
|
||||
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
d_in.GetDeviceBuffer(),
|
||||
d_out.GetDeviceBuffer(),
|
||||
d_out_index.GetDeviceBuffer(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
|
||||
|
||||
float latency = 0;
|
||||
try
|
||||
{
|
||||
latency = launch_selected_kernel(host_args, stream);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Kernel launch failed: " << e.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t bytes_read = static_cast<size_t>(N) * D * H * W * C * sizeof(InDataType);
|
||||
size_t bytes_written = static_cast<size_t>(N) * Do * Ho * Wo * C * sizeof(OutDataType);
|
||||
float bandwidth = (bytes_read + bytes_written) / (latency * 1e-3f) / 1e9f;
|
||||
|
||||
std::cout << " Latency: " << latency << " ms" << std::endl;
|
||||
std::cout << " Bandwidth: " << bandwidth << " GB/s" << std::endl;
|
||||
|
||||
if(verify)
|
||||
{
|
||||
d_out.FromDevice(h_out.data());
|
||||
d_out_index.FromDevice(h_out_index.data());
|
||||
|
||||
auto kernel_args =
|
||||
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
h_in.data(),
|
||||
h_out_ref.data(),
|
||||
h_out_ref_index.data(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_lengths),
|
||||
SelectedKernel::kOutputIndex>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-3, 1e-3);
|
||||
std::cout << " Verification: " << (pass_value ? "PASS" : "FAIL") << std::endl;
|
||||
|
||||
if(SelectedKernel::kOutputIndex)
|
||||
{
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
|
||||
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Program entry point: run the pooling benchmark for the dimensionality
// selected at compile time (POOLING_DIM comes from the generated kernel header).
int main(int argc, char* argv[])
{
    return benchmark_pooling<POOLING_DIM>(argc, argv);
}
|
||||
52
tile_engine/ops/pooling/pooling_common.hpp
Normal file
52
tile_engine/ops/pooling/pooling_common.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/pooling.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Trait parameters describing one pooling tile_engine kernel variant.
struct PoolingKernelTraits
{
    std::string reduce_op; // one of "max", "min", or "avg"
    bool output_index;     // true when the kernel also emits indices (max pooling)
    bool propagate_nan;    // true when NaN inputs are propagated to the output
    bool cross_warp;       // true when a cross-warp reduction is used

    /// @brief Encode the traits as "<op>_<idx|noidx>_<nan|nonan>_<crosswarp|nocrosswarp>".
    std::string to_string() const
    {
        std::string repr = reduce_op;
        repr += output_index ? "_idx" : "_noidx";
        repr += propagate_nan ? "_nan" : "_nonan";
        repr += cross_warp ? "_crosswarp" : "_nocrosswarp";
        return repr;
    }
};
|
||||
|
||||
/// @brief Extract traits from a kernel name string
|
||||
inline PoolingKernelTraits extract_pooling_traits_from_name(const std::string& name)
|
||||
{
|
||||
PoolingKernelTraits traits;
|
||||
if(name.find("max") != std::string::npos)
|
||||
traits.reduce_op = "max";
|
||||
else if(name.find("min") != std::string::npos)
|
||||
traits.reduce_op = "min";
|
||||
else
|
||||
traits.reduce_op = "avg";
|
||||
traits.output_index =
|
||||
(name.find("idx") != std::string::npos) && (name.find("noidx") == std::string::npos);
|
||||
traits.propagate_nan =
|
||||
(name.find("nan") != std::string::npos) && (name.find("nonan") == std::string::npos);
|
||||
traits.cross_warp = (name.find("crosswarp") != std::string::npos) &&
|
||||
(name.find("nocrosswarp") == std::string::npos);
|
||||
return traits;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
551
tile_engine/ops/pooling/pooling_instance_builder.py
Normal file
551
tile_engine/ops/pooling/pooling_instance_builder.py
Normal file
@@ -0,0 +1,551 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Pooling kernel instance builder for tile_engine.
|
||||
|
||||
Generates C++ kernel headers for pooling operations with specific tile
|
||||
configurations and trait combinations.
|
||||
|
||||
Usage:
|
||||
--list_kernels: List valid kernel configurations
|
||||
--gen_single: Generate a single kernel header
|
||||
--gen_individual: Generate all kernel headers
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from pooling_validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
get_dtype_string,
|
||||
get_reduce_op_string,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoolingKernelBuilder:
    """Generates C++ kernel-instance headers for pooling tile_engine kernels.

    The builder enumerates tile configurations and trait combinations (from a
    JSON config file or built-in defaults), filters them through the validators
    in pooling_validation_utils, and renders one ``pooling_single_<name>.hpp``
    header per surviving combination into ``working_path``.
    """

    def __init__(self, working_path, datatype, config_json=None):
        """Create a builder.

        working_path: output directory for generated files (created if missing).
        datatype:     datatype token, e.g. "fp16" (used in names and C++ types).
        config_json:  optional path to a JSON config; defaults are used when the
                      path is falsy or does not exist.
        """
        self.working_path = Path(working_path)
        self.datatype = datatype
        self.config_json = config_json

        # Create working directory if it doesn't exist
        self.working_path.mkdir(parents=True, exist_ok=True)

        # Load configuration
        if config_json and os.path.exists(config_json):
            with open(config_json, "r") as f:
                self.config = json.load(f)
        else:
            self.config = self._get_default_config()

    def _get_default_config(self):
        """Return default configuration if no config file is provided.

        Shape matches the expected JSON: two top-level sections,
        "tile_config" and "trait_config", each mapping a parameter name to a
        {"values": [...]} sweep list.
        """
        return {
            "tile_config": {
                "block_m": {"values": [64,128,256]},
                "block_n": {"values": [1,2]},
                "warp_m": {"values": [1]},
                "warp_n": {"values": [1]},
                "warp_tile_m": {"values": [128]},
                "warp_tile_n": {"values": [1]},
                "thread_tile_m": {"values": [1,2,4]},
                "thread_tile_n": {"values": [1]},
            },
            "trait_config": {
                "reduce_op": {"values": ["max", "min", "avg"]},
                "output_index": {"values": [True, False]},
                "propagate_nan": {"values": [True, False]},
                "pooling_dim": {"values": ["2d", "3d"]},
            },
        }

    def _get_tile_configs(self, fast_mode=False):
        """Return the list of valid tile-configuration dicts from the config.

        Takes the Cartesian product of the eight per-parameter sweep lists and
        keeps only combinations accepted by _validate_tile_config.
        fast_mode is forwarded to the validator unchanged.
        """
        if "tile_config" not in self.config:
            return []

        tile_config = self.config["tile_config"]

        # Per-parameter fallbacks mirror _get_default_config.
        block_m_values = tile_config.get("block_m", {}).get("values", [64,128,256])
        block_n_values = tile_config.get("block_n", {}).get("values", [1,2])
        warp_m_values = tile_config.get("warp_m", {}).get("values", [1])
        warp_n_values = tile_config.get("warp_n", {}).get("values", [1])
        warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [128])
        warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [1])
        thread_tile_m_values = tile_config.get("thread_tile_m", {}).get("values", [1,2,4])
        thread_tile_n_values = tile_config.get("thread_tile_n", {}).get("values", [1])

        configs = []
        for block_m in block_m_values:
            for block_n in block_n_values:
                for warp_m in warp_m_values:
                    for warp_n in warp_n_values:
                        for warp_tile_m in warp_tile_m_values:
                            for warp_tile_n in warp_tile_n_values:
                                for thread_tile_m in thread_tile_m_values:
                                    for thread_tile_n in thread_tile_n_values:
                                        if self._validate_tile_config(
                                            block_m,
                                            block_n,
                                            warp_m,
                                            warp_n,
                                            warp_tile_m,
                                            warp_tile_n,
                                            thread_tile_m,
                                            thread_tile_n,
                                            fast_mode=fast_mode,
                                        ):
                                            configs.append(
                                                {
                                                    "block_m": block_m,
                                                    "block_n": block_n,
                                                    "warp_m": warp_m,
                                                    "warp_n": warp_n,
                                                    "warp_tile_m": warp_tile_m,
                                                    "warp_tile_n": warp_tile_n,
                                                    "thread_tile_m": thread_tile_m,
                                                    "thread_tile_n": thread_tile_n,
                                                }
                                            )
        return configs

    def _validate_tile_config(
        self,
        block_m,
        block_n,
        warp_m,
        warp_n,
        warp_tile_m,
        warp_tile_n,
        thread_tile_m,
        thread_tile_n,
        fast_mode=False,
    ):
        """Validate tile configuration via pooling_validation_utils.

        self.datatype is passed twice — presumably once for the input and once
        for the output datatype (TODO confirm against is_tile_config_valid's
        signature).
        """
        return is_tile_config_valid(
            block_m,
            block_n,
            warp_m,
            warp_n,
            warp_tile_m,
            warp_tile_n,
            thread_tile_m,
            thread_tile_n,
            self.datatype,
            self.datatype,
            fast_mode=fast_mode,
        )

    def _generate_trait_combinations(self):
        """Return valid (reduce_op, output_index, propagate_nan, pooling_dim) tuples.

        Takes the product of the four trait sweep lists and drops combinations
        rejected by is_trait_combination_valid. When the config has no
        "trait_config" section, a single default combination is returned.
        """
        if "trait_config" not in self.config:
            return [("max", True, False, "2d")]

        trait_config = self.config["trait_config"]

        reduce_ops = trait_config.get("reduce_op", {}).get("values", ["min","max","avg"])
        output_indices = trait_config.get("output_index", {}).get("values", [True, False])
        propagate_nans = trait_config.get("propagate_nan", {}).get("values", [True, False])
        pooling_dims = trait_config.get("pooling_dim", {}).get("values", ["2d", "3d"])

        all_combinations = list(
            itertools.product(reduce_ops, output_indices, propagate_nans, pooling_dims)
        )

        # Filter valid combinations
        combinations = []
        for combo in all_combinations:
            reduce_op, output_index, propagate_nan, pooling_dim = combo
            if is_trait_combination_valid(
                reduce_op, output_index, propagate_nan, pooling_dim
            ):
                combinations.append(combo)
            else:
                logger.debug(
                    f"Skipping unsupported trait combination: {reduce_op}-{output_index}-{propagate_nan}-{pooling_dim}"
                )

        return combinations

    def _get_dtype_string(self):
        """Get the C++ type string for self.datatype (via pooling_validation_utils)."""
        return get_dtype_string(self.datatype)

    def _get_reduce_op_string(self, reduce_op):
        """Get the C++ reduce-op type string for a reduce_op token (via pooling_validation_utils)."""
        return get_reduce_op_string(reduce_op)

    def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
        """Render one kernel instance as C++ source.

        tile_config: dict of the eight tile parameters (see _get_tile_configs).
        trait_combo: (reduce_op, output_index, propagate_nan, pooling_dim).
        is_header:   when True, a "#pragma once" line is emitted.
        Returns (kernel_name, instance_code) — the unique kernel identifier and
        the full C++ text defining KERNEL_NAME, POOLING_DIM and SelectedKernel.
        """
        reduce_op, output_index, propagate_nan, pooling_dim = trait_combo

        # Create kernel name, e.g. "pool_fp16_2d_max_idx_nonan".
        # NOTE(review): write_kernel_list builds the same name independently —
        # the two sites must stay in sync.
        kernel_name = (
            f"pool_{self.datatype}_{pooling_dim}_{reduce_op}_"
            f"{'idx' if output_index else 'noidx'}_"
            f"{'nan' if propagate_nan else 'nonan'}"
        )

        # Create tile configuration string
        tile_str = (
            f"{tile_config['block_m']}x{tile_config['block_n']}_"
            f"{tile_config['warp_m']}x{tile_config['warp_n']}_"
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}_"
            f"{tile_config['thread_tile_m']}x{tile_config['thread_tile_n']}"
        )

        kernel_name += f"_{tile_str}"

        # Determine types
        in_type = self._get_dtype_string()
        out_type = in_type
        compute_type = "float"  # Always use float for computation
        index_type = "ck_tile::index_t"
        reduce_op_type = self._get_reduce_op_string(reduce_op)

        output_index_str = "true" if output_index else "false"
        propagate_nan_str = "true" if propagate_nan else "false"

        # Generate 2D or 3D specific code: NHWC (rank-4) vs NDHWC (rank-5)
        # tensor tuples, and a matching window tuple.
        if pooling_dim == "2d":
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t>"
            window_rank = 2
        else:
            tensor_shape_type = "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            window_shape_type = (
                "ck_tile::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>"
            )
            window_rank = 3

        pragma_line = "#pragma once\n" if is_header else ""
        instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/pooling.hpp"

using InDataType = {in_type};
using OutDataType = {out_type};
using ComputeDataType = {compute_type};
using IndexDataType = {index_type};
using ReduceOpType = {reduce_op_type};

using TensorShape = {tensor_shape_type};
using WindowShape = {window_shape_type};

// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
constexpr int POOLING_DIM = {window_rank};

// Wrapper for simplified launch interface
struct SelectedKernel {{
    // Tile configuration - PoolShape parameters
    static constexpr ck_tile::index_t Block_M = {tile_config["block_m"]};
    static constexpr ck_tile::index_t Block_N = {tile_config["block_n"]};
    static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
    static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
    static constexpr ck_tile::index_t WarpTile_M = {tile_config["warp_tile_m"]};
    static constexpr ck_tile::index_t WarpTile_N = {tile_config["warp_tile_n"]};
    static constexpr ck_tile::index_t ThreadTile_M = {tile_config["thread_tile_m"]};
    static constexpr ck_tile::index_t ThreadTile_N = {tile_config["thread_tile_n"]};

    // Traits
    static constexpr bool kOutputIndex = {output_index_str};
    static constexpr bool kPropagateNan = {propagate_nan_str};

    // Pool shape
    using BlockWarps = ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N>;
    using BlockTile = ck_tile::sequence<Block_M, Block_N>;
    using WarpTile = ck_tile::sequence<WarpTile_M, WarpTile_N>;
    using ThreadTile = ck_tile::sequence<ThreadTile_M, ThreadTile_N>;

    using PoolShapeType = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;

    // Problem and kernel types
    using Problem = ck_tile::PoolProblem<InDataType,
                                         OutDataType,
                                         ComputeDataType,
                                         IndexDataType,
                                         ReduceOpType,
                                         kOutputIndex,
                                         kPropagateNan,
                                         PoolShapeType>;
    using Kernel = ck_tile::PoolKernel<Problem>;

    static float launch(ck_tile::PoolHostArgs<TensorShape, WindowShape>& args,
                        const ck_tile::stream_config& stream) {{

        constexpr ck_tile::index_t kBlockPerCu = 1;
        const ck_tile::index_t kBlockSize = Kernel::BlockSize();

        auto kernel_args = Kernel::MakeKernelArgs(args);

        if (!Kernel::IsSupportedArgument(kernel_args)) {{
            throw std::runtime_error(
                std::string("Unsupported arguments for pooling kernel: ") + KERNEL_NAME);
        }}

        const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);

        if(stream.log_level_ > 0) {{
            std::cout << "Launching pooling kernel: " << KERNEL_NAME << "\\n"
                      << " grid_size: " << kGridSize << ", block_size: " << kBlockSize
                      << std::endl;
        }}

        return ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, kGridSize, kBlockSize, 0, kernel_args));
    }}
}};
"""
        return kernel_name, instance_code

    def write_kernel_list(self):
        """Write the kernel count and list to files for CMake to read.

        Produces two files in working_path:
          pool_kernel_count.txt — number of kernels, and
          pool_kernel_list.txt  — one "name|tile_str|trait_str" line per kernel.
        No kernel headers are generated here.
        """
        tile_configs = self._get_tile_configs(fast_mode=False)
        trait_combos = self._generate_trait_combinations()

        kernel_list = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                reduce_op, output_index, propagate_nan, pooling_dim = trait_combo

                # NOTE(review): name construction duplicates
                # _generate_kernel_instance — keep both sites in sync.
                kernel_name = (
                    f"pool_{self.datatype}_{pooling_dim}_{reduce_op}_"
                    f"{'idx' if output_index else 'noidx'}_"
                    f"{'nan' if propagate_nan else 'nonan'}"
                )

                tile_str = (
                    f"{tile_config['block_m']}x{tile_config['block_n']}_"
                    f"{tile_config['warp_m']}x{tile_config['warp_n']}_"
                    f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}_"
                    f"{tile_config['thread_tile_m']}x{tile_config['thread_tile_n']}"
                )

                kernel_name += f"_{tile_str}"

                trait_str = (
                    f"{reduce_op}_"
                    f"{'true' if output_index else 'false'}_"
                    f"{'true' if propagate_nan else 'false'}_"
                    f"{pooling_dim}"
                )

                kernel_list.append(
                    {
                        "name": kernel_name,
                        "tile_config": tile_config,
                        "trait_combo": trait_combo,
                        "tile_str": tile_str,
                        "trait_str": trait_str,
                    }
                )

        # Write kernel count
        with open(self.working_path / "pool_kernel_count.txt", "w") as f:
            f.write(str(len(kernel_list)))

        # Write kernel list
        with open(self.working_path / "pool_kernel_list.txt", "w") as f:
            for kernel in kernel_list:
                f.write(
                    f"{kernel['name']}|{kernel['tile_str']}|{kernel['trait_str']}\n"
                )

        print(f"Listed {len(kernel_list)} kernel configurations")

    def generate_individual(self, num_workers=None):
        """Generate one header file per kernel using a process pool.

        num_workers defaults to min(cpu_count, 8). Failures of individual
        kernels are printed and skipped; they do not abort the run.
        """
        if num_workers is None:
            num_workers = min(multiprocessing.cpu_count(), 8)

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # One work item per (tile, trait) pair; workers rebuild the builder
        # from (working_path, datatype), so only picklable data is passed.
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.working_path,
                        self.datatype,
                    )
                )

        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )

        kernel_list = []
        completed = 0

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }

            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 10 == 0 or completed == len(work_items):
                    print(
                        f" Progress: {completed}/{len(work_items)} kernels generated"
                    )

                try:
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")

        # Sort by kernel name for a deterministic summary.
        kernel_list.sort(key=lambda x: x[0])
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )

    def run(self, num_workers=None):
        """Run the builder to generate individual kernel files."""
        self.generate_individual(num_workers)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
    """Process-pool worker: emit one generated pooling kernel header.

    work_item is a (tile_config, trait_combo, working_path, datatype) tuple.
    Returns (kernel_name, trait_combo, tile_config) on success, or None when
    rendering or the file write fails (the error is printed, not raised).
    """
    tile_cfg, traits, out_dir, dtype = work_item

    # Each worker process builds its own generator from the picklable fields.
    gen = PoolingKernelBuilder(out_dir, dtype)

    try:
        name, code = gen._generate_kernel_instance(tile_cfg, traits)

        target = out_dir / f"pooling_single_{name}.hpp"
        with open(target, "w") as fh:
            fh.write(code)

        return (name, traits, tile_cfg)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
|
||||
|
||||
|
||||
def main():
    """CLI driver: list kernel configs, generate all headers, or generate one.

    Exactly one of --list_kernels / --gen_individual / --gen_single must be
    given; --gen_single additionally requires --kernel_name, --tile_config and
    --trait_combo. Exits via parser.error() on invalid combinations.
    """
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description="Pooling kernel instance builder for tile_engine"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp8", "fp16", "bf16", "fp32"],
        help="Data type",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_individual", action="store_true", help="Generate individual kernel files"
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    builder = PoolingKernelBuilder(args.working_path, args.datatype, args.config_json)

    if args.list_kernels:
        builder.write_kernel_list()
    elif args.gen_single:
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Parse tile config:
        # "block_mxblock_n_warp_mxwarp_n_warp_tile_mxwarp_tile_n_thread_tile_mxthread_tile_n"
        # i.e. four underscore-separated "MxN" pairs (e.g. "64x1_1x1_128x1_1x1").
        tile_parts = args.tile_config.split("_")
        block_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        thread_tile_dims = tile_parts[3].split("x")

        tile_config = {
            "block_m": int(block_dims[0]),
            "block_n": int(block_dims[1]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "thread_tile_m": int(thread_tile_dims[0]),
            "thread_tile_n": int(thread_tile_dims[1]),
        }

        # Parse trait combo: "reduce_op_output_index_propagate_nan_pooling_dim"
        # (e.g. "max_true_false_2d").
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # reduce_op
            trait_parts[1].lower() == "true",  # output_index
            trait_parts[2].lower() == "true",  # propagate_nan
            trait_parts[3],  # pooling_dim
        )

        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        header_file = builder.working_path / f"pooling_single_{kernel_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)

        print(f"Generated {header_file}")

    elif args.gen_individual:
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_individual, or --gen_single"
        )
|
||||
|
||||
|
||||
# Script entry point: delegate to the CLI driver.
if __name__ == "__main__":
    main()
|
||||
487
tile_engine/ops/pooling/pooling_validation_utils.py
Normal file
487
tile_engine/ops/pooling/pooling_validation_utils.py
Normal file
@@ -0,0 +1,487 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Validation utilities for pooling tile_engine configurations.
|
||||
|
||||
Validates tile configurations, trait combinations, and datatype support for
|
||||
pooling kernels. Modelled after gemm_validation_utils.py — each constraint
|
||||
from the CK PoolShape / PoolKernel static_asserts is mirrored here so that
|
||||
invalid configs are rejected at code-generation time rather than at compile
|
||||
or runtime.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------

# Default warp size (wave64 for CDNA architectures)
WARP_SIZE = 64
MAX_BLOCK_SIZE = 1024  # Maximum threads per workgroup on AMD GPUs
# NOTE(review): MAX_LDS_BYTES is not referenced by any validator in this
# file — confirm it is consumed elsewhere (e.g. the instance builder) or
# remove it.
MAX_LDS_BYTES = 65536  # 64 KB LDS per workgroup
|
||||
|
||||
def get_warp_size_for_gpu(gpu_target: str) -> int:
    """Return the wavefront width (threads per warp) for *gpu_target*.

    CDNA architectures (gfx9xx) run WAVE64; RDNA architectures
    (gfx10xx/gfx11xx/gfx12xx) and any other target default to WAVE32.
    """
    is_cdna = gpu_target.startswith("gfx9")
    return 64 if is_cdna else 32
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Datatype helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Byte width of one element per datatype token. Fractional widths (int4)
# are permitted; callers must treat the result as a float.
ELEMENT_SIZE_MAP = {
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "fp16": 2,
    "bf16": 2,
    "int4": 0.5,
    "int32": 4,
    "fp32": 4,
    "fp64": 8,
}

# Datatype token -> C++ type string emitted into generated kernel instances.
# NOTE(review): int4/int8/int32 have entries in ELEMENT_SIZE_MAP but no C++
# mapping here, so get_dtype_string() silently falls back to "float" for
# them — confirm integer pooling types are intentionally unsupported.
DTYPE_STRING_MAP = {
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp32": "float",
    "fp64": "double",
}

# Only datatypes with a C++ mapping are advertised as supported.
SUPPORTED_DATATYPES = list(DTYPE_STRING_MAP.keys())
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reduce-op helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Reduce-op token -> C++ ReduceOp enumerator.
# NOTE(review): "avg" maps to Add — presumably the kernel divides by the
# pooling-window size afterwards to produce the average; confirm against
# the pooling kernel implementation.
REDUCE_OP_STRING_MAP = {
    "max": "ck_tile::ReduceOp::Max",
    "min": "ck_tile::ReduceOp::Min",
    "avg": "ck_tile::ReduceOp::Add",
}

SUPPORTED_REDUCE_OPS = list(REDUCE_OP_STRING_MAP.keys())

# Spatial ranks of pooling supported by the generator.
SUPPORTED_POOLING_DIMS = ("2d", "3d")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public helper functions (used by the instance builder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def element_size(datatype: str) -> float:
    """Return the byte-width of a single element for *datatype*.

    Lookup is case-insensitive.

    Raises
    ------
    ValueError
        If *datatype* is not listed in ``ELEMENT_SIZE_MAP``.
    """
    key = datatype.lower()
    if key in ELEMENT_SIZE_MAP:
        return ELEMENT_SIZE_MAP[key]
    raise ValueError(
        f"Unsupported data type: '{key}'. "
        f"Supported: {list(ELEMENT_SIZE_MAP.keys())}"
    )
|
||||
|
||||
|
||||
def get_dtype_string(datatype: str) -> str:
    """Return the C++ type string (e.g. ``ck_tile::fp16_t``) for *datatype*.

    Lookup is now case-insensitive, for consistency with ``element_size``
    (previously ``"FP16"`` silently fell through to the ``"float"``
    default). Unknown datatypes still fall back to ``"float"`` to preserve
    historical behavior.
    """
    return DTYPE_STRING_MAP.get(datatype.lower(), "float")
|
||||
|
||||
|
||||
def get_reduce_op_string(reduce_op: str) -> str:
    """Return the C++ ReduceOp enumerator string for *reduce_op*.

    Lookup is now case-insensitive, for consistency with ``element_size``
    (previously ``"Min"`` silently fell through to the ``Max`` default).
    Unknown ops still fall back to ``ck_tile::ReduceOp::Max`` to preserve
    historical behavior.
    """
    return REDUCE_OP_STRING_MAP.get(reduce_op.lower(), "ck_tile::ReduceOp::Max")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Individual tile-config validators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_positivity(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters must be positive integers."""
    named = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    # Report the first non-positive parameter encountered, in declaration order.
    for label, value in named:
        if value <= 0:
            return False, f"{label} ({value}) must be > 0"
    return True, ""
|
||||
|
||||
|
||||
def validate_power_of_two(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """All tile parameters should be powers of two for correct GPU addressing."""
    named = (
        ("block_m", block_m),
        ("block_n", block_n),
        ("warp_m", warp_m),
        ("warp_n", warp_n),
        ("warp_tile_m", warp_tile_m),
        ("warp_tile_n", warp_tile_n),
        ("thread_tile_m", thread_tile_m),
        ("thread_tile_n", thread_tile_n),
    )
    for label, value in named:
        # value & (value - 1) clears the lowest set bit; a non-zero result
        # means more than one bit is set, i.e. not a power of two.
        # Non-positive values are skipped (validate_positivity handles them).
        if value > 0 and value & (value - 1):
            return False, f"{label} ({value}) is not a power of two"
    return True, ""
|
||||
|
||||
|
||||
def validate_thread_tile_alignment(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert(Warp_M % ThreadTile_M == 0);
        static_assert(Warp_N % ThreadTile_N == 0);

    (Here ``warp_tile_*`` corresponds to ``Warp_*`` in the C++ source.)
    """
    axes = (
        ("m", warp_tile_m, thread_tile_m),
        ("n", warp_tile_n, thread_tile_n),
    )
    for axis, warp_extent, thread_extent in axes:
        if warp_extent % thread_extent:
            return (
                False,
                f"warp_tile_{axis} ({warp_extent}) must be divisible by "
                f"thread_tile_{axis} ({thread_extent})",
            )
    return True, ""
|
||||
|
||||
|
||||
def validate_warp_thread_distribution(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N)
                      % get_warp_size() == 0);

    Additionally rejects the degenerate case where the warp tile holds
    fewer elements than a single thread tile: integer division then yields
    0 threads per warp, and ``0 % warp_size == 0`` would silently accept
    an unusable configuration.
    """
    threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    if threads_per_warp == 0:
        return (
            False,
            f"warp tile ({warp_tile_m}x{warp_tile_n}) holds fewer elements "
            f"than one thread tile ({thread_tile_m}x{thread_tile_n})",
        )
    if threads_per_warp % warp_size != 0:
        return (
            False,
            f"(warp_tile_m * warp_tile_n) / (thread_tile_m * thread_tile_n) = "
            f"{threads_per_warp} is not a multiple of warp_size ({warp_size})",
        )
    return True, ""
|
||||
|
||||
|
||||
def _compute_warp_size_scale_factors(
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[int, int]:
    """
    Reproduce the WarpSizeScaleFactor_M / _N logic from pool_shape.hpp.

    The scale factor (threads-per-warp-tile divided by the hardware warp
    size) is attributed to the dimension with the larger lane count; the
    N dimension wins ties.
    """
    total_threads = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
    scale = total_threads // warp_size
    m_dominates = (warp_tile_m // thread_tile_m) > (warp_tile_n // thread_tile_n)
    return (scale, 1) if m_dominates else (1, scale)
|
||||
|
||||
|
||||
def validate_block_tile_coverage(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Mirrors pool_shape.hpp:
        static_assert((Block_M * WarpSizeScaleFactor_M) %
                      (WarpPerBlock_M * Warp_M) == 0);
        static_assert((Block_N * WarpSizeScaleFactor_N) %
                      (WarpPerBlock_N * Warp_N) == 0);
    """
    sf = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    per_axis = (
        ("m", block_m, sf[0], warp_m, warp_tile_m),
        ("n", block_n, sf[1], warp_n, warp_tile_n),
    )
    for axis, block, scale, warps, warp_tile in per_axis:
        if (block * scale) % (warps * warp_tile):
            upper = axis.upper()
            return (
                False,
                f"block_{axis}*ScaleFactor_{upper} "
                f"({block}*{scale}={block * scale}) must be divisible by "
                f"warp_{axis}*warp_tile_{axis} "
                f"({warps}*{warp_tile}={warps * warp_tile})",
            )
    return True, ""
|
||||
|
||||
|
||||
def validate_block_size(
    warp_m: int,
    warp_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """BlockSize = warp_size * warp_m * warp_n must be <= MAX_BLOCK_SIZE."""
    threads = warp_size * warp_m * warp_n
    if threads <= MAX_BLOCK_SIZE:
        return True, ""
    return (
        False,
        f"BlockSize ({threads} = {warp_size}*{warp_m}*{warp_n}) "
        f"exceeds maximum ({MAX_BLOCK_SIZE})",
    )
|
||||
|
||||
|
||||
def validate_vector_load_alignment(
    block_m: int,
    thread_tile_m: int,
    in_datatype: str,
) -> Tuple[bool, str]:
    """
    The M-dimension thread-tile determines the contiguous vector load width.
    It must produce a load whose byte-width divides 16 bytes (max global
    vector load width on AMD GPUs) and is at least 1 element wide.

    ``block_m`` is accepted for call-signature symmetry with the other
    validators but does not participate in the check.
    """
    elem_bytes = element_size(in_datatype)
    load_bytes = thread_tile_m * elem_bytes
    if load_bytes > 16:
        return (
            False,
            f"thread_tile_m ({thread_tile_m}) * element_size({in_datatype}, "
            f"{elem_bytes}B) = {load_bytes}B exceeds 16B max vector load",
        )
    # load_bytes <= 16 is guaranteed here, so the original extra
    # `load_bytes % 16 != 0` clause was redundant (it only mattered at
    # exactly 16B, where 16 % 16 == 0 passes anyway). The check reduces to:
    # the load width must evenly divide 16 bytes.
    if 16 % load_bytes != 0:
        return (
            False,
            f"Vector load width ({load_bytes}B) is not a divisor of 16B",
        )
    return True, ""
|
||||
|
||||
|
||||
def validate_repeat_factors(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    warp_size: int = WARP_SIZE,
) -> Tuple[bool, str]:
    """
    Repeat_M and Repeat_N from pool_shape.hpp must be >= 1. They are the
    number of tile iterations each warp performs within the block.

    ``warp_size`` was previously hard-wired to the module default (64,
    wave64); it is now an explicit parameter (defaulting to the old value,
    so existing callers are unaffected) so that wave32 (RDNA) targets can
    compute scale factors consistently with the other validators.
    """
    sf_m, sf_n = _compute_warp_size_scale_factors(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    repeat_m = (block_m * sf_m) // (warp_m * warp_tile_m)
    repeat_n = (block_n * sf_n) // (warp_n * warp_tile_n)
    if repeat_m < 1:
        return False, f"Repeat_M ({repeat_m}) must be >= 1"
    if repeat_n < 1:
        return False, f"Repeat_N ({repeat_n}) must be >= 1"
    return True, ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Comprehensive tile-config validation (entry point)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_tile_config_valid(
    block_m: int,
    block_n: int,
    warp_m: int,
    warp_n: int,
    warp_tile_m: int,
    warp_tile_n: int,
    thread_tile_m: int,
    thread_tile_n: int,
    in_datatype: str,
    out_datatype: str,
    fast_mode: bool = False,
    gpu_target: str = "gfx90a",
) -> bool:
    """
    Comprehensive pooling tile configuration validation.

    When *fast_mode* is True only cheap sanity checks are performed (useful
    for the ``--list_kernels`` path). Full mode mirrors every
    ``static_assert`` in ``pool_shape.hpp``.

    Returns True when every check passes; on the first failure logs the
    reason at DEBUG level and returns False.

    Parameters
    ----------
    block_m, block_n : Block tile dimensions (M = output elems, N = window).
    warp_m, warp_n : Warps per block along each dimension.
    warp_tile_m, warp_tile_n : Tile processed per warp.
    thread_tile_m, thread_tile_n : Contiguous elements per thread.
    in_datatype : Input element type (e.g. ``"fp16"``).
    out_datatype : Output element type.
        NOTE(review): accepted but not consulted by any check below —
        confirm whether output-side alignment should also be validated.
    fast_mode : Skip expensive checks when True.
    gpu_target : GPU arch string; selects wave32 vs wave64 warp size.
    """
    # Positional bundle matching the 8-parameter validators' signatures.
    all_params = (
        block_m, block_n, warp_m, warp_n,
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n,
    )

    # --- Positivity (always) ---
    ok, err = validate_positivity(*all_params)
    if not ok:
        logger.debug(f"Positivity check failed: {err}")
        return False

    # --- Thread-tile alignment (always) ---
    ok, err = validate_thread_tile_alignment(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n
    )
    if not ok:
        logger.debug(f"Thread tile alignment failed: {err}")
        return False

    # Fast mode ends here: only the two cheap structural checks above run.
    if fast_mode:
        return True

    # Get the warp size for this GPU target
    warp_size = get_warp_size_for_gpu(gpu_target)

    # --- Power-of-two ---
    ok, err = validate_power_of_two(*all_params)
    if not ok:
        logger.debug(f"Power-of-two check failed: {err}")
        return False

    # --- Warp-thread distribution ---
    ok, err = validate_warp_thread_distribution(
        warp_tile_m, warp_tile_n, thread_tile_m, thread_tile_n, warp_size
    )
    if not ok:
        logger.debug(f"Warp thread distribution failed: {err}")
        return False

    # --- Block-tile coverage ---
    ok, err = validate_block_tile_coverage(*all_params, warp_size=warp_size)
    if not ok:
        logger.debug(f"Block tile coverage failed: {err}")
        return False

    # --- Block size ---
    ok, err = validate_block_size(warp_m, warp_n, warp_size)
    if not ok:
        logger.debug(f"Block size check failed: {err}")
        return False

    # --- Repeat factors ---
    # NOTE(review): warp_size is not forwarded here, so this check always
    # uses the wave64 module default even for wave32 (RDNA) targets —
    # confirm whether it should receive warp_size like the checks above.
    ok, err = validate_repeat_factors(*all_params)
    if not ok:
        logger.debug(f"Repeat factor check failed: {err}")
        return False

    # --- Vector load alignment ---
    # Only the input datatype drives the vector width; out_datatype unused.
    ok, err = validate_vector_load_alignment(block_m, thread_tile_m, in_datatype)
    if not ok:
        logger.debug(f"Vector load alignment failed: {err}")
        return False

    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trait-combination validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_trait_combination_valid(
    reduce_op: str,
    output_index: bool,
    propagate_nan: bool,
    pooling_dim: str,
) -> bool:
    """
    Validate a pooling trait combination.

    Parameters
    ----------
    reduce_op : ``"max"``, ``"min"``, or ``"avg"``.
    output_index : Whether to output indices of the selected elements.
    propagate_nan: Whether to propagate NaN values through the reduction.
    pooling_dim : ``"2d"`` or ``"3d"``.
    """
    if reduce_op not in SUPPORTED_REDUCE_OPS:
        logger.debug(f"Unsupported reduce_op: '{reduce_op}'")
        return False

    if pooling_dim not in SUPPORTED_POOLING_DIMS:
        logger.debug(f"Invalid pooling dimension: '{pooling_dim}'")
        return False

    # Index output is only meaningful when the reduction selects a single
    # source element, i.e. max pooling (CK constraint).
    index_ok = (reduce_op == "max") or not output_index
    if not index_ok:
        logger.debug(
            f"output_index=True is only supported for 'max' pooling, "
            f"not '{reduce_op}'"
        )
        return False

    # propagate_nan carries no additional constraints.
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Datatype validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_datatype_supported(datatype: str) -> bool:
    """Return True if *datatype* is a known pooling datatype."""
    normalized = datatype.lower()
    return normalized in ELEMENT_SIZE_MAP
|
||||
@@ -11,7 +11,7 @@ set(MULTI_REDUCE_VARIANTS "multiops_multiblock;multiops_threadwise" CACHE STRING
|
||||
function(build_multi_reduce_for_datatype datatype variant)
|
||||
# Filter GPU targets to only gfx942, and gfx950
|
||||
set(GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx942;gfx950")
|
||||
set(DESIRED_TARGETS "gfx942;gfx950;gfx12-generic")
|
||||
set(VALID_VARIANTS "multiops_multiblock;multiops_threadwise")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
@@ -22,7 +22,7 @@ function(build_multi_reduce_for_datatype datatype variant)
|
||||
|
||||
# Skip compilation if no matching targets found
|
||||
if(NOT GPU_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user