[CK TILE ENGINE] Code changes for finding GPU id from TARGET (#3055)

* Reading gpuname from target for gemm in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Get GPU changes for GEMM Multi D in TILE ENGINE

* Addressing errors for GPU name in CK tile engine
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-20 11:02:18 -05:00
committed by GitHub
parent f18b79f328
commit 9f77061094
12 changed files with 59 additions and 149 deletions

View File

@@ -57,6 +57,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
--kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
--gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}"
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
@@ -163,7 +164,8 @@ function(build_individual_gemm_targets datatype layout)
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels")
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# First, just list the kernels (fast operation)
message(STATUS " Listing kernel configurations...")
@@ -173,7 +175,8 @@ function(build_individual_gemm_targets datatype layout)
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output

View File

@@ -7,10 +7,6 @@
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {
"fp32": "float",
"fp16": "ck_tile::half_t",
@@ -212,31 +208,3 @@ def element_size(data_type: str) -> float:
if data_type not in ELEMENT_SIZE_MAP:
raise ValueError(f"Unsupported data type: {data_type}")
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
print("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
print("GPU query timeout (5s)")
except Exception as e:
print(f"GPU detection error: {str(e)}")
return ""

View File

@@ -15,8 +15,9 @@ logging.basicConfig(level=logging.INFO)
class GemmKernelBuilder:
def __init__(self, working_path, datatype, layout, config_json=None):
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
self.working_path = Path(working_path)
self.gpu_target = gpu_target
self.datatype = datatype
self.layout = layout
self.config_json = config_json
@@ -231,6 +232,7 @@ class GemmKernelBuilder:
b_datatype,
c_datatype,
pipeline,
self.gpu_target,
)
def _generate_trait_combinations(self):
@@ -822,6 +824,11 @@ def main():
description="GEMM kernel instance builder with parallel support"
)
parser.add_argument("--working_path", required=True, help="Working directory path")
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"--datatype",
required=True,
@@ -861,7 +868,7 @@ def main():
# Create builder
builder = GemmKernelBuilder(
args.working_path, args.datatype, args.layout, args.config_json
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
)
if args.list_kernels:

View File

@@ -7,7 +7,6 @@ from validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
validate_warp_tile_combination,
get_gpu_name_by_id,
)
@@ -16,8 +15,7 @@ def test_warp_tile_validation():
print("Testing warp tile combination validation...")
# Get GPU name
gpu_name = get_gpu_name_by_id(0)
print(f"Detected GPU: {gpu_name}")
gpu_name = "gfx90a"
# Test cases for fp16
test_cases = [

View File

@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
Extracted from tile_engine_develop for consistency.
"""
import subprocess
import re
from functools import lru_cache
import logging
from typing import Tuple, List
@@ -152,34 +149,6 @@ def element_size(data_type: str) -> float:
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
logging.debug("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
logging.debug("GPU query timeout (5s)")
except Exception as e:
logging.debug(f"GPU detection error: {str(e)}")
return ""
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
"""Check if a trait combination is valid."""
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
@@ -189,11 +158,9 @@ def validate_warp_configuration(
warp_m: int,
warp_n: int,
warp_k: int,
gpu_name: str = None,
gpu_name: str,
) -> bool:
"""Validate warp configuration."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)
current_combination = [warp_m, warp_n, warp_k]
@@ -274,11 +241,9 @@ def validate_warp_tile_combination(
a_datatype: str,
b_datatype: str,
c_datatype: str,
gpu_name: str = None,
gpu_name: str,
) -> Tuple[bool, str]:
"""Validate warp tile combination against GPU-specific supported combinations."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)
# Construct the key for looking up supported combinations
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
@@ -325,6 +290,7 @@ def is_tile_config_valid(
b_datatype: str,
c_datatype: str,
pipeline: str,
gpu_target: str,
trait_name: str = None,
) -> bool:
"""
@@ -348,7 +314,7 @@ def is_tile_config_valid(
return False
# Validate warp configuration
if not validate_warp_configuration(warp_m, warp_n, warp_k):
if not validate_warp_configuration(warp_m, warp_n, warp_k, gpu_target):
logging.debug(
f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
)
@@ -384,7 +350,13 @@ def is_tile_config_valid(
# Validate warp tile combination
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
warp_tile_m,
warp_tile_n,
warp_tile_k,
a_datatype,
b_datatype,
c_datatype,
gpu_target,
)
if not warp_tile_valid:
logging.debug(f"Warp tile validation failed: {warp_tile_error}")

View File

@@ -43,6 +43,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--list_blobs
--gpu_target ${GEMM_GPU_TARGETS}
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
@@ -62,6 +63,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json "${json_blob}"
--gen_blobs
--gpu_target ${GEMM_GPU_TARGETS}
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
)
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})

View File

@@ -7,10 +7,6 @@
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {
"fp32": "float",
"fp16": "ck_tile::half_t",
@@ -198,31 +194,3 @@ def element_size(data_type: str) -> float:
if data_type not in ELEMENT_SIZE_MAP:
raise ValueError(f"Unsupported data type: {data_type}")
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
print("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
print("GPU query timeout (5s)")
except Exception as e:
print(f"GPU detection error: {str(e)}")
return ""

View File

@@ -22,7 +22,6 @@ from gemm_multi_d_codegen_utils import (
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size,
get_gpu_name_by_id,
)
import logging
@@ -40,6 +39,8 @@ class GemmMultiDCodeGenerator:
self.output_dir = Path(args.working_path)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.gpu_target = args.gpu_target
if user_provided_config is not None:
self.config = user_provided_config
else:
@@ -261,7 +262,7 @@ class GemmMultiDCodeGenerator:
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
gpu_name = get_gpu_name_by_id(0)
gpu_name = self.gpu_target
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
@@ -713,6 +714,11 @@ if __name__ == "__main__":
required=False,
help="The path where all the blobs are going to be generated",
)
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"-j",
"--config_json",

View File

@@ -57,6 +57,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
--kernel_name "gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
--gpu_target "${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}"
DEPENDS ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
@@ -160,9 +161,11 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
# First, just list the kernels (fast operation)
message(STATUS " Listing kernel configurations...")
message(STATUS " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py
--working_path ${working_path}
--gpu_target ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}

View File

@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
Extracted from tile_engine_develop for consistency.
"""
import subprocess
import re
from functools import lru_cache
import logging
from typing import Tuple, List
@@ -123,34 +120,6 @@ def element_size(data_type: str) -> float:
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
logging.debug("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
logging.debug("GPU query timeout (5s)")
except Exception as e:
logging.debug(f"GPU detection error: {str(e)}")
return ""
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
"""Check if a trait combination is valid."""
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
@@ -225,11 +194,9 @@ def validate_warp_tile_combination(
a_datatype: str,
b_datatype: str,
c_datatype: str,
gpu_name: str = None,
gpu_name: str,
) -> Tuple[bool, str]:
"""Validate warp tile combination against GPU-specific supported combinations."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)
# Construct the key for looking up supported combinations
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
@@ -276,6 +243,7 @@ def is_tile_config_valid(
b_datatype: str,
c_datatype: str,
pipeline: str,
gpu_target: str,
trait_name: str = None,
) -> bool:
"""
@@ -335,7 +303,13 @@ def is_tile_config_valid(
# Validate warp tile combination
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
warp_tile_m,
warp_tile_n,
warp_tile_k,
a_datatype,
b_datatype,
c_datatype,
gpu_target,
)
if not warp_tile_valid:
logging.debug(f"Warp tile validation failed: {warp_tile_error}")

View File

@@ -17,8 +17,9 @@ from commons.validation_utils import (
class GemmPreshuffleKernelBuilder:
def __init__(self, working_path, datatype, layout, config_json=None):
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
self.working_path = Path(working_path)
self.gpu_target = gpu_target
self.datatype = datatype
self.layout = layout
self.config_json = config_json
@@ -294,6 +295,7 @@ class GemmPreshuffleKernelBuilder:
b_datatype,
c_datatype,
pipeline,
self.gpu_target,
)
def _generate_kernel_instance(
@@ -711,6 +713,11 @@ def main():
description="GEMM kernel instance builder with parallel support"
)
parser.add_argument("--working_path", required=True, help="Working directory path")
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"--datatype",
required=True,
@@ -765,7 +772,7 @@ def main():
# Create builder
builder = GemmPreshuffleKernelBuilder(
args.working_path, args.datatype, args.layout, args.config_json
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
)
if args.list_kernels: