[rocm-libraries] ROCm/rocm-libraries#5438 (commit 7000562)

[CK_TILE] Normalize gpu_target before LDS_SIZE_MAP lookup (#5438)

GPU targets passed with feature suffixes (e.g. `gfx950:xnack+`) were
falling through to `DEFAULT_LDS_SIZE` instead of matching their entry in
`LDS_SIZE_MAP`, potentially causing incorrect tile acceptance/rejection.

## Changes

- **`gemm_validation_utils.py`**: Strip the first `:` and everything after
it from `gpu_target` before the `LDS_SIZE_MAP` lookup; use the normalized
base arch name in the error message as well.

```python
# Before
hw_lds_size = LDS_SIZE_MAP.get(gpu_target, DEFAULT_LDS_SIZE)

# After
base_gpu_target = gpu_target.split(":")[0] if gpu_target else gpu_target
hw_lds_size = LDS_SIZE_MAP.get(base_gpu_target, DEFAULT_LDS_SIZE)
```
This commit is contained in:
Sami Remes
2026-05-29 16:33:15 +00:00
committed by assistant-librarian[bot]
parent 8bd8094012
commit b619c374eb
2 changed files with 54 additions and 24 deletions

View File

@@ -154,6 +154,21 @@ class GemmKernelBuilder:
persistent,
) = trait_combo
# Skip if this tile config is not valid for this specific pipeline
if not self._validate_tile_config(
tile_config["tile_m"],
tile_config["tile_n"],
tile_config["tile_k"],
tile_config["warp_m"],
tile_config["warp_n"],
tile_config["warp_k"],
tile_config["warp_tile_m"],
tile_config["warp_tile_n"],
tile_config["warp_tile_k"],
pipeline,
):
continue
# Create kernel name with proper boolean capitalization
kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
@@ -240,15 +255,12 @@ class GemmKernelBuilder:
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
default_pipeline = ""
if self.kernel_name_prefix == "gemm_universal":
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_multi_d":
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_preshuffle":
default_pipeline = "preshufflev2"
elif self.kernel_name_prefix == "grouped_gemm":
default_pipeline = "compv4"
pipelines = self.config["trait_config"].get("pipeline", {}).get("values", [])
if not pipelines:
if self.kernel_name_prefix == "gemm_preshuffle":
pipelines = ["preshufflev2"]
else:
pipelines = ["compv4"]
configs = []
for tile_m in tile_m_values:
@@ -260,18 +272,21 @@ class GemmKernelBuilder:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
default_pipeline,
# Accept tile if valid for any pipeline
if any(
self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline,
)
for pipeline in pipelines
):
configs.append(
{

View File

@@ -282,6 +282,16 @@ def validate_dimension_alignment(
return len(alignment_issues) == 0, alignment_issues
LDS_SIZE_MAP = {
"gfx90a": 2**16, # 64KB
"gfx942": 2**16, # 64KB
"gfx950": 160 * 1024, # 160KB
"gfx1201": 2**16, # 64KB
}
DEFAULT_LDS_SIZE = 2**16 # 64KB
def validate_lds_capacity(
tile_m: int,
tile_n: int,
@@ -289,18 +299,23 @@ def validate_lds_capacity(
a_datatype: str,
b_datatype: str,
pipeline: str,
gpu_target: str = "",
) -> Tuple[bool, str]:
"""Validate LDS capacity requirements."""
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
base_gpu_target = gpu_target.split(":")[0] if gpu_target else gpu_target
hw_lds_size = LDS_SIZE_MAP.get(base_gpu_target, DEFAULT_LDS_SIZE)
double_buffer = pipeline in ["preshufflev2", "compv4"]
max_tile_size = hw_lds_size // 2 if double_buffer else hw_lds_size
if total_tile_in_lds > max_tile_size:
error_msg = (
f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB) "
f"[{base_gpu_target}, {'double' if double_buffer else 'single'} buffer]. Breakdown:\n"
f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
)
@@ -461,7 +476,7 @@ def is_tile_config_valid(
# Validate LDS capacity
lds_valid, lds_error = validate_lds_capacity(
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline, gpu_target
)
if not lds_valid:
logging.debug(f"LDS validation failed: {lds_error}")