mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Add support for gfx12 in tile_engine for GEMM benchmarking (#2802)
* initial work on adding support of gfx12 in tile_engine for GEMM benchmarking
* add stage("Run TILE_ENGINE_GEMM Tests on gfx1201") to Jenkins config
* make tile_[m/n/k] validation arch dependent
[ROCm/composable_kernel commit: 592d73ad73]
This commit is contained in:
62
Jenkinsfile
vendored
62
Jenkinsfile
vendored
@@ -157,9 +157,9 @@ def getDockerImage(Map conf=[:]){
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
//Check if image exists
|
||||
//Check if image exists
|
||||
def retimage
|
||||
try
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${image}"
|
||||
retimage = docker.image("${image}")
|
||||
@@ -232,7 +232,7 @@ def cmake_build(Map conf=[:]){
|
||||
def setup_args = conf.get("setup_args","")
|
||||
// make sure all unit tests always run on develop branch
|
||||
def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS
|
||||
|
||||
|
||||
if (prefixpath != "/usr/local"){
|
||||
setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} "
|
||||
}
|
||||
@@ -357,7 +357,7 @@ def cmake_build(Map conf=[:]){
|
||||
"build_cmd",
|
||||
"${build_envs} ninja -j${nt} ${config_targets}"
|
||||
)
|
||||
|
||||
|
||||
cmd = conf.get("cmd", """
|
||||
${setup_cmd}
|
||||
${build_cmd}
|
||||
@@ -449,7 +449,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
checkout scm
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts
|
||||
if ( params.BUILD_INSTANCES_ONLY ){
|
||||
dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
@@ -515,7 +515,7 @@ def Build_CK(Map conf=[:]){
|
||||
checkout scm
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
@@ -719,7 +719,7 @@ def process_results(Map conf=[:]){
|
||||
def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3"
|
||||
def prefixpath = "/opt/rocm"
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
@@ -956,20 +956,20 @@ pipeline {
|
||||
defaultValue: '',
|
||||
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
|
||||
string(
|
||||
name: 'ROCMVERSION',
|
||||
name: 'ROCMVERSION',
|
||||
defaultValue: '6.4.1',
|
||||
description: 'Specify which ROCM version to use: 6.4.1 (default).')
|
||||
string(
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: '',
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: '',
|
||||
description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline, or leave blank (default).')
|
||||
string(
|
||||
name: 'COMPILER_COMMIT',
|
||||
defaultValue: '',
|
||||
name: 'COMPILER_COMMIT',
|
||||
defaultValue: '',
|
||||
description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.')
|
||||
string(
|
||||
name: 'BUILD_COMPILER',
|
||||
defaultValue: '/opt/rocm/llvm/bin/clang++',
|
||||
name: 'BUILD_COMPILER',
|
||||
defaultValue: '/opt/rocm/llvm/bin/clang++',
|
||||
description: 'Build CK with /opt/rocm/bin/hipcc, /llvm-project/build/bin/clang++, or with /opt/rocm/llvm/bin/clang++ (default).')
|
||||
booleanParam(
|
||||
name: "RUN_FULL_QA",
|
||||
@@ -1448,6 +1448,36 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx1201")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1201") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx1201" \
|
||||
-D GEMM_DATATYPE="fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-DGEMM_CONFIG_FILE=gfx120x_config.json \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_fp16_rcr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rrr && \
|
||||
ninja -j64 benchmark_gemm_fp16_crr && \
|
||||
ninja -j64 benchmark_gemm_fp16_ccr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1591,7 +1621,7 @@ pipeline {
|
||||
agent{ label rocmnode("gfx942") }
|
||||
steps{
|
||||
script {
|
||||
def execute_args = params.NINJA_FTIME_TRACE ?
|
||||
def execute_args = params.NINJA_FTIME_TRACE ?
|
||||
""" cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
@@ -1600,7 +1630,7 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
|
||||
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
|
||||
}
|
||||
cleanWs()
|
||||
|
||||
@@ -20,7 +20,7 @@ fi
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx942"
|
||||
|
||||
if [ $# -ge 1 ]; then
|
||||
case "$1" in
|
||||
case "$1" in
|
||||
gfx*)
|
||||
GPU_TARGETS=$1
|
||||
shift 1
|
||||
@@ -38,7 +38,7 @@ fi
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=512" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=ON \
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
|
||||
@@ -13,38 +13,38 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
|
||||
# First split by underscore to get three groups
|
||||
string(REPLACE "_" ";" config_groups ${tile_config})
|
||||
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
|
||||
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
|
||||
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
|
||||
|
||||
|
||||
# Parse tile dimensions
|
||||
string(REPLACE "x" ";" tile_parts ${tile_dims})
|
||||
list(GET tile_parts 0 tile_m)
|
||||
list(GET tile_parts 1 tile_n)
|
||||
list(GET tile_parts 2 tile_k)
|
||||
|
||||
|
||||
# Parse warp dimensions
|
||||
string(REPLACE "x" ";" warp_parts ${warp_dims})
|
||||
list(GET warp_parts 0 warp_m)
|
||||
list(GET warp_parts 1 warp_n)
|
||||
list(GET warp_parts 2 warp_k)
|
||||
|
||||
|
||||
# Parse warp tile dimensions
|
||||
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
|
||||
list(GET warp_tile_parts 0 warp_tile_m)
|
||||
list(GET warp_tile_parts 1 warp_tile_n)
|
||||
list(GET warp_tile_parts 2 warp_tile_k)
|
||||
|
||||
|
||||
set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
|
||||
# Generate the single instance header for this kernel
|
||||
set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
@@ -60,27 +60,27 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
add_executable(${target_name}
|
||||
${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})
|
||||
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${GEMM_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
@@ -88,19 +88,19 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_gemm_all ${target_name})
|
||||
add_dependencies(benchmark_gemm_${datatype} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${layout} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name})
|
||||
|
||||
|
||||
# Add to trait-specific targets
|
||||
string(REPLACE "_" ";" trait_parts ${trait})
|
||||
list(GET trait_parts 0 pipeline)
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
|
||||
add_dependencies(benchmark_gemm_${pipeline} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${epilogue} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${scheduler} ${target_name})
|
||||
@@ -109,13 +109,13 @@ endfunction()
|
||||
# Function to build individual GEMM targets
|
||||
function(build_individual_gemm_targets datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
|
||||
# Choose config file
|
||||
# Priority order:
|
||||
# 1. Environment variable GEMM_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_CONFIG_FILE
|
||||
# 3. Default based on layout
|
||||
|
||||
|
||||
# Check environment variable first
|
||||
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
|
||||
@@ -130,12 +130,12 @@ function(build_individual_gemm_targets datatype layout)
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(STATUS " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
|
||||
# Check if config file exists
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
|
||||
# Determine number of workers for parallel generation
|
||||
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
@@ -147,17 +147,24 @@ function(build_individual_gemm_targets datatype layout)
|
||||
set(num_workers 8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(STATUS " Working path: ${working_path}")
|
||||
message(STATUS " Config file: ${json_blob}")
|
||||
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
|
||||
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
|
||||
message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels")
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(STATUS " Listing kernel configurations...")
|
||||
execute_process(
|
||||
@@ -172,11 +179,11 @@ function(build_individual_gemm_targets datatype layout)
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
|
||||
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
@@ -185,7 +192,7 @@ function(build_individual_gemm_targets datatype layout)
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
|
||||
|
||||
# Read kernel list and create targets
|
||||
if(EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
@@ -195,7 +202,7 @@ function(build_individual_gemm_targets datatype layout)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
|
||||
# Create individual target
|
||||
create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endforeach()
|
||||
@@ -210,9 +217,9 @@ message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}")
|
||||
message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, and gfx950
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
|
||||
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
@@ -223,13 +230,13 @@ endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
@@ -179,6 +179,11 @@ warp_tile_supported_combinations = {
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
"gfx1201": {
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
|
||||
102
tile_engine/ops/gemm/configs/gfx120x_config.json
Normal file
102
tile_engine/ops/gemm/configs/gfx120x_config.json
Normal file
@@ -0,0 +1,102 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
256,
|
||||
128,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,6 +103,36 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
"gfx1201": {
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
@@ -155,9 +185,32 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) ->
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool:
|
||||
def validate_warp_configuration(
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
gpu_name: str = None,
|
||||
) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
current_combination = [warp_m, warp_n, warp_k]
|
||||
|
||||
allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
if not allowed_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}")
|
||||
return True
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_dimension_alignment(
|
||||
|
||||
Reference in New Issue
Block a user