mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK TILE ENGINE] Code changes to finding GPU id from TARGET (#3055)
* Reading gpuname from target for gemm in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Get GPU changes for GEMM Muti D in TILE ENGINE * Addressing errors for gpu name in cktileengine
This commit is contained in:
committed by
GitHub
parent
f18b79f328
commit
9f77061094
@@ -57,6 +57,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_
|
||||
--kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
@@ -163,7 +164,8 @@ function(build_individual_gemm_targets datatype layout)
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels")
|
||||
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(STATUS " Listing kernel configurations...")
|
||||
@@ -173,7 +175,8 @@ function(build_individual_gemm_targets datatype layout)
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
|
||||
@@ -7,10 +7,6 @@
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
@@ -212,31 +208,3 @@ def element_size(data_type: str) -> float:
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
print("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
print(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
|
||||
@@ -15,8 +15,9 @@ logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmKernelBuilder:
|
||||
def __init__(self, working_path, datatype, layout, config_json=None):
|
||||
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
|
||||
self.working_path = Path(working_path)
|
||||
self.gpu_target = gpu_target
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.config_json = config_json
|
||||
@@ -231,6 +232,7 @@ class GemmKernelBuilder:
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
@@ -822,6 +824,11 @@ def main():
|
||||
description="GEMM kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
@@ -861,7 +868,7 @@ def main():
|
||||
|
||||
# Create builder
|
||||
builder = GemmKernelBuilder(
|
||||
args.working_path, args.datatype, args.layout, args.config_json
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
@@ -7,7 +7,6 @@ from validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
validate_warp_tile_combination,
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,8 +15,7 @@ def test_warp_tile_validation():
|
||||
print("Testing warp tile combination validation...")
|
||||
|
||||
# Get GPU name
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
print(f"Detected GPU: {gpu_name}")
|
||||
gpu_name = "gfx90a"
|
||||
|
||||
# Test cases for fp16
|
||||
test_cases = [
|
||||
|
||||
@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
@@ -152,34 +149,6 @@ def element_size(data_type: str) -> float:
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
logging.debug("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
logging.debug("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
logging.debug(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
@@ -189,11 +158,9 @@ def validate_warp_configuration(
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
gpu_name: str = None,
|
||||
gpu_name: str,
|
||||
) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
current_combination = [warp_m, warp_n, warp_k]
|
||||
|
||||
@@ -274,11 +241,9 @@ def validate_warp_tile_combination(
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str = None,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
@@ -325,6 +290,7 @@ def is_tile_config_valid(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -348,7 +314,7 @@ def is_tile_config_valid(
|
||||
return False
|
||||
|
||||
# Validate warp configuration
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k):
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k, gpu_target):
|
||||
logging.debug(
|
||||
f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
|
||||
)
|
||||
@@ -384,7 +350,13 @@ def is_tile_config_valid(
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
|
||||
@@ -43,6 +43,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--list_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(NOT ret EQUAL 0)
|
||||
@@ -62,6 +63,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json "${json_blob}"
|
||||
--gen_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
|
||||
)
|
||||
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
|
||||
|
||||
@@ -7,10 +7,6 @@
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
@@ -198,31 +194,3 @@ def element_size(data_type: str) -> float:
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
print("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
print(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
|
||||
@@ -22,7 +22,6 @@ from gemm_multi_d_codegen_utils import (
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size,
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
import logging
|
||||
|
||||
@@ -40,6 +39,8 @@ class GemmMultiDCodeGenerator:
|
||||
self.output_dir = Path(args.working_path)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.gpu_target = args.gpu_target
|
||||
|
||||
if user_provided_config is not None:
|
||||
self.config = user_provided_config
|
||||
else:
|
||||
@@ -261,7 +262,7 @@ class GemmMultiDCodeGenerator:
|
||||
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
gpu_name = self.gpu_target
|
||||
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
@@ -713,6 +714,11 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
help="The path where all the blobs are going to be generated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--config_json",
|
||||
|
||||
@@ -57,6 +57,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
|
||||
--kernel_name "gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
@@ -160,9 +161,11 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(STATUS " Listing kernel configurations...")
|
||||
message(STATUS " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--gpu_target ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
|
||||
@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
@@ -123,34 +120,6 @@ def element_size(data_type: str) -> float:
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
logging.debug("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
logging.debug("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
logging.debug(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
@@ -225,11 +194,9 @@ def validate_warp_tile_combination(
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str = None,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
@@ -276,6 +243,7 @@ def is_tile_config_valid(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -335,7 +303,13 @@ def is_tile_config_valid(
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
|
||||
@@ -17,8 +17,9 @@ from commons.validation_utils import (
|
||||
|
||||
|
||||
class GemmPreshuffleKernelBuilder:
|
||||
def __init__(self, working_path, datatype, layout, config_json=None):
|
||||
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
|
||||
self.working_path = Path(working_path)
|
||||
self.gpu_target = gpu_target
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.config_json = config_json
|
||||
@@ -294,6 +295,7 @@ class GemmPreshuffleKernelBuilder:
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
def _generate_kernel_instance(
|
||||
@@ -711,6 +713,11 @@ def main():
|
||||
description="GEMM kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
@@ -765,7 +772,7 @@ def main():
|
||||
|
||||
# Create builder
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
args.working_path, args.datatype, args.layout, args.config_json
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
Reference in New Issue
Block a user