mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Update pre-commit to fixed versions, run remod for ck_tile (#2895)
* Fix ruff linter errors * Fix remod dos2unix command * Clang format * Ignore utility in remod * Run remod * Specify clang-format version in pre-commit * Specify ruff version * Include PoolKernelArgs in reference_pool * Add calculate_total_elements to reference batched contraction * Fix calculate_total_elements declaration * Refactor remod pre-commit hook * Fix Aquant tests --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -170,11 +170,11 @@ warp_tile_supported_combinations = {
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_bf8_fp16": [
|
||||
"fp8_bf8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_fp8_fp16": [
|
||||
"bf8_fp8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
|
||||
@@ -107,32 +107,32 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
],
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
@@ -186,14 +186,14 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) ->
|
||||
|
||||
|
||||
def validate_warp_configuration(
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
gpu_name: str = None,
|
||||
) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
current_combination = [warp_m, warp_n, warp_k]
|
||||
|
||||
@@ -205,11 +205,8 @@ def validate_warp_configuration(
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user