From e0187eab4108a1988faa518e0c96fea5ccc44772 Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Tue, 2 Sep 2025 11:32:01 -0400 Subject: [PATCH] Fixing python backward compatibility issue in benchmarking script. [ROCm/composable_kernel commit: 0e322200e5c959bddea8dda101197f685bc3c22c] --- tile_engine/ops/gemm/gemm_instance_builder.py | 3 ++- tile_engine/ops/gemm/validation_utils.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index d679be7b84..c2214da613 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -8,6 +8,7 @@ import multiprocessing import concurrent.futures from pathlib import Path import logging +from typing import Optional from validation_utils import is_tile_config_valid, is_trait_combination_valid logging.basicConfig(level=logging.INFO) @@ -325,7 +326,7 @@ class GemmKernelBuilder: "c": "ck_tile::tensor_layout::gemm::ColumnMajor", } - def _get_abc_layouts(self, layout_code: str | None = None): + def _get_abc_layouts(self, layout_code: Optional[str] = None): """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. If layout_code is None, use self.layout. diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 4948fd5744..7367f2446d 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -11,6 +11,7 @@ import subprocess import re from functools import lru_cache import logging +from typing import Tuple, List # Element size mapping for different data types ELEMENT_SIZE_MAP = { @@ -169,7 +170,7 @@ def validate_dimension_alignment( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, -) -> tuple[bool, list[str]]: +) -> Tuple[bool, List[str]]: """Check if tile dimensions are properly aligned with warp dimensions.""" alignment_issues = [] @@ -196,7 +197,7 @@ def validate_lds_capacity( a_datatype: str, b_datatype: str, pipeline: str, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate LDS capacity requirements.""" matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) @@ -224,7 +225,7 @@ def validate_warp_tile_combination( b_datatype: str, c_datatype: str, gpu_name: str = None, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" if gpu_name is None: gpu_name = get_gpu_name_by_id(0)