[CK_TILE] Add support for gfx12 in tile_engine for GEMM benchmarking (#2802)

* initial work on adding support of gfx12 in tile_engine for GEMM benchmarking
* add stage("Run TILE_ENGINE_GEMM Tests on gfx1201") to Jenkins config
* make tile_[m/n/k] validation arch dependent
This commit is contained in:
pmaybank
2025-09-17 17:59:01 +01:00
committed by GitHub
parent c2997f2b7f
commit 592d73ad73
6 changed files with 249 additions and 52 deletions

View File

@@ -13,38 +13,38 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
return()
endif()
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
# First split by underscore to get three groups
string(REPLACE "_" ";" config_groups ${tile_config})
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
# Parse tile dimensions
string(REPLACE "x" ";" tile_parts ${tile_dims})
list(GET tile_parts 0 tile_m)
list(GET tile_parts 1 tile_n)
list(GET tile_parts 2 tile_k)
# Parse warp dimensions
string(REPLACE "x" ";" warp_parts ${warp_dims})
list(GET warp_parts 0 warp_m)
list(GET warp_parts 1 warp_n)
list(GET warp_parts 2 warp_k)
# Parse warp tile dimensions
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
list(GET warp_tile_parts 0 warp_tile_m)
list(GET warp_tile_parts 1 warp_tile_n)
list(GET warp_tile_parts 2 warp_tile_k)
set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Add custom command to generate the header file at build time
add_custom_command(
OUTPUT ${instance_header}
@@ -60,27 +60,27 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
# Create the executable
add_executable(${target_name}
add_executable(${target_name}
${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp
${instance_header}
)
# Set GPU architectures
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})
# Set compile definitions
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
)
# Include directories
target_include_directories(${target_name} PRIVATE
${GEMM_SOURCE_DIR}
${working_path}
)
# Compile options
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template
@@ -88,19 +88,19 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
--offload-compress
-include ${instance_header}
)
# Add to collection targets
add_dependencies(benchmark_gemm_all ${target_name})
add_dependencies(benchmark_gemm_${datatype} ${target_name})
add_dependencies(benchmark_gemm_${layout} ${target_name})
add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name})
# Add to trait-specific targets
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 pipeline)
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_gemm_${pipeline} ${target_name})
add_dependencies(benchmark_gemm_${epilogue} ${target_name})
add_dependencies(benchmark_gemm_${scheduler} ${target_name})
@@ -109,13 +109,13 @@ endfunction()
# Function to build individual GEMM targets
function(build_individual_gemm_targets datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Choose config file
# Priority order:
# 1. Environment variable GEMM_CONFIG_FILE
# 2. CMake variable GEMM_CONFIG_FILE
# 2. CMake variable GEMM_CONFIG_FILE
# 3. Default based on layout
# Check environment variable first
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
@@ -130,12 +130,12 @@ function(build_individual_gemm_targets datatype layout)
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
message(STATUS " Using default config for layout ${layout}")
endif()
# Check if config file exists
if(NOT EXISTS ${json_blob})
message(FATAL_ERROR "Config file not found: ${json_blob}")
endif()
# Determine number of workers for parallel generation
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
@@ -147,17 +147,24 @@ function(build_individual_gemm_targets datatype layout)
set(num_workers 8)
endif()
endif()
# Generate individual kernel files using parallel version
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
message(STATUS " Working path: ${working_path}")
message(STATUS " Config file: ${json_blob}")
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels")
# First, just list the kernels (fast operation)
message(STATUS " Listing kernel configurations...")
execute_process(
@@ -172,11 +179,11 @@ function(build_individual_gemm_targets datatype layout)
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
# Read kernel count
if(EXISTS ${working_path}/gemm_kernel_count.txt)
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
@@ -185,7 +192,7 @@ function(build_individual_gemm_targets datatype layout)
else()
message(FATAL_ERROR "Kernel count file not found")
endif()
# Read kernel list and create targets
if(EXISTS ${working_path}/gemm_kernel_list.txt)
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
@@ -195,7 +202,7 @@ function(build_individual_gemm_targets datatype layout)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Create individual target
create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
endforeach()
@@ -210,9 +217,9 @@ message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}")
message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, and gfx950
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -223,13 +230,13 @@ endforeach()
# Skip build if no matching targets found
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
set_property(GLOBAL PROPERTY JOB_POOLS
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4 # Limit heavy compilations to prevent OOM
compile_normal=16 # Allow more parallel normal compilations
)

View File

@@ -179,6 +179,11 @@ warp_tile_supported_combinations = {
[32, 32, 64],
],
},
"gfx1201": {
"fp16_fp16_fp16": [
[16, 16, 16],
],
},
}
# To Do: remove some unsupported combinations

View File

@@ -0,0 +1,102 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
256,
128,
64
]
},
"tile_n": {
"values": [
256,
128,
64
]
},
"tile_k": {
"values": [
256,
128,
64
]
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -103,6 +103,36 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[32, 32, 64],
],
},
"gfx1201": {
"fp16_fp16_fp16": [
[16, 16, 16],
],
},
}
# Supported warp combinations for different GPU architectures.
# Keys are GPU architecture names; each value lists the allowed
# [warp_m, warp_n, warp_k] triples for that architecture. Unlike
# WARP_TILE_SUPPORTED_COMBINATIONS, this table is not keyed by data type.
# validate_warp_configuration() looks up its [warp_m, warp_n, warp_k] list here.
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx942": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx950": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx1201": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1],
],
}
# Unsupported trait combinations
@@ -155,9 +185,32 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) ->
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
def validate_warp_configuration(
    warp_m: int,
    warp_n: int,
    warp_k: int,
    gpu_name: str = None,
) -> bool:
    """Check whether a warp configuration is allowed on the given GPU.

    Args:
        warp_m, warp_n, warp_k: Components of the warp configuration; they are
            compared as the list [warp_m, warp_n, warp_k] against the entries
            of WARP_SUPPORTED_COMBINATIONS.
        gpu_name: Architecture key into WARP_SUPPORTED_COMBINATIONS
            (e.g. "gfx1201"). When None, it is obtained from
            get_gpu_name_by_id(0).

    Returns:
        True if the combination is listed for ``gpu_name``. Also True (after
        logging a warning) when no combinations are registered for
        ``gpu_name``, so unknown GPUs are not rejected. False otherwise.
    """
    if gpu_name is None:
        gpu_name = get_gpu_name_by_id(0)

    current_combination = [warp_m, warp_n, warp_k]

    # The table maps GPU name -> list of triples, so fall back to an empty list.
    allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, [])
    if not allowed_combinations:
        # If GPU not recognized, try to be permissive but log warning
        logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}")
        return True

    # Check if current combination is in the allowed list
    if current_combination not in allowed_combinations:
        # The return type is a plain bool, so the reason is logged rather than
        # returned; debug level is used to keep this helper quiet by default.
        logging.debug(
            f"Invalid warp combination: {current_combination} not in allowed list. "
            f"Valid combinations on {gpu_name}: {allowed_combinations}"
        )
        return False

    return True
def validate_dimension_alignment(