mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
python format
This commit is contained in:
@@ -159,25 +159,25 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
}
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
@@ -200,6 +200,7 @@ def element_size(data_type: str) -> float:
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
@@ -212,15 +213,15 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
|
||||
|
||||
arch_pattern = r'gfx\d{3,4}[a-z]?'
|
||||
match = re.search(arch_pattern, result.stdout.lower())
|
||||
return match.group() if match else ""
|
||||
|
||||
|
||||
except (FileNotFoundError, subprocess.CalledProcessError) as e:
|
||||
print(f"System Error: {str(e)}, when get the name of gpu:{gpu_id}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
print(f"Runtime Exception: {str(e)}, when get the name of gpu:{gpu_id}")
|
||||
print(
|
||||
f"Runtime Exception: {str(e)}, when get the name of gpu:{gpu_id}")
|
||||
return ""
|
||||
|
||||
|
||||
@@ -376,18 +376,20 @@ struct GemmKernel {{
|
||||
# Warp combination validation
|
||||
warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}"
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
logging.warning(f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
logging.warning(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
return True
|
||||
|
||||
|
||||
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
logging.warning(f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
logging.warning(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
|
||||
return True
|
||||
|
||||
|
||||
if current_combination not in allowed_combinations:
|
||||
logging.warning(
|
||||
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
|
||||
|
||||
Reference in New Issue
Block a user