Fixing python backward compatibility issue in benchmarking script.

This commit is contained in:
Vidyasagar Ananthan
2025-09-02 11:32:01 -04:00
committed by Aviral Goel
parent bab747b017
commit 0e322200e5
2 changed files with 6 additions and 4 deletions

View File

@@ -8,6 +8,7 @@ import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from typing import Optional
from validation_utils import is_tile_config_valid, is_trait_combination_valid
logging.basicConfig(level=logging.INFO)
@@ -325,7 +326,7 @@ class GemmKernelBuilder:
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
def _get_abc_layouts(self, layout_code: str | None = None):
def _get_abc_layouts(self, layout_code: Optional[str] = None):
"""
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
If layout_code is None, use self.layout.

View File

@@ -11,6 +11,7 @@ import subprocess
import re
from functools import lru_cache
import logging
from typing import Tuple, List
# Element size mapping for different data types
ELEMENT_SIZE_MAP = {
@@ -169,7 +170,7 @@ def validate_dimension_alignment(
warp_tile_m: int,
warp_tile_n: int,
warp_tile_k: int,
) -> tuple[bool, list[str]]:
) -> Tuple[bool, List[str]]:
"""Check if tile dimensions are properly aligned with warp dimensions."""
alignment_issues = []
@@ -196,7 +197,7 @@ def validate_lds_capacity(
a_datatype: str,
b_datatype: str,
pipeline: str,
) -> tuple[bool, str]:
) -> Tuple[bool, str]:
"""Validate LDS capacity requirements."""
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
@@ -224,7 +225,7 @@ def validate_warp_tile_combination(
b_datatype: str,
c_datatype: str,
gpu_name: str = None,
) -> tuple[bool, str]:
) -> Tuple[bool, str]:
"""Validate warp tile combination against GPU-specific supported combinations."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)