mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#5438 (commit 7000562)
[CK_TILE] Normalize gpu_target before LDS_SIZE_MAP lookup (#5438)

GPU targets passed with feature suffixes (e.g. `gfx950:xnack+`) were falling through to `DEFAULT_LDS_SIZE` instead of matching their entry in `LDS_SIZE_MAP`, potentially causing incorrect tile acceptance/rejection.

## Changes

- **`gemm_validation_utils.py`**: Strip everything after `:` from `gpu_target` before the `LDS_SIZE_MAP` lookup; use the normalized base arch name in the error message as well.

```python
# Before
hw_lds_size = LDS_SIZE_MAP.get(gpu_target, DEFAULT_LDS_SIZE)

# After
base_gpu_target = gpu_target.split(":")[0] if gpu_target else gpu_target
hw_lds_size = LDS_SIZE_MAP.get(base_gpu_target, DEFAULT_LDS_SIZE)
```
This commit is contained in:
committed by
assistant-librarian[bot]
parent
8bd8094012
commit
b619c374eb
@@ -154,6 +154,21 @@ class GemmKernelBuilder:
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Skip if this tile config is not valid for this specific pipeline
|
||||
if not self._validate_tile_config(
|
||||
tile_config["tile_m"],
|
||||
tile_config["tile_n"],
|
||||
tile_config["tile_k"],
|
||||
tile_config["warp_m"],
|
||||
tile_config["warp_n"],
|
||||
tile_config["warp_k"],
|
||||
tile_config["warp_tile_m"],
|
||||
tile_config["warp_tile_n"],
|
||||
tile_config["warp_tile_k"],
|
||||
pipeline,
|
||||
):
|
||||
continue
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
@@ -240,15 +255,12 @@ class GemmKernelBuilder:
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
default_pipeline = ""
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
default_pipeline = "preshufflev2"
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
default_pipeline = "compv4"
|
||||
pipelines = self.config["trait_config"].get("pipeline", {}).get("values", [])
|
||||
if not pipelines:
|
||||
if self.kernel_name_prefix == "gemm_preshuffle":
|
||||
pipelines = ["preshufflev2"]
|
||||
else:
|
||||
pipelines = ["compv4"]
|
||||
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
@@ -260,18 +272,21 @@ class GemmKernelBuilder:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
default_pipeline,
|
||||
# Accept tile if valid for any pipeline
|
||||
if any(
|
||||
self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline,
|
||||
)
|
||||
for pipeline in pipelines
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
|
||||
@@ -282,6 +282,16 @@ def validate_dimension_alignment(
|
||||
return len(alignment_issues) == 0, alignment_issues
|
||||
|
||||
|
||||
LDS_SIZE_MAP = {
|
||||
"gfx90a": 2**16, # 64KB
|
||||
"gfx942": 2**16, # 64KB
|
||||
"gfx950": 160 * 1024, # 160KB
|
||||
"gfx1201": 2**16, # 64KB
|
||||
}
|
||||
|
||||
DEFAULT_LDS_SIZE = 2**16 # 64KB
|
||||
|
||||
|
||||
def validate_lds_capacity(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
@@ -289,18 +299,23 @@ def validate_lds_capacity(
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
pipeline: str,
|
||||
gpu_target: str = "",
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate LDS capacity requirements."""
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
|
||||
base_gpu_target = gpu_target.split(":")[0] if gpu_target else gpu_target
|
||||
hw_lds_size = LDS_SIZE_MAP.get(base_gpu_target, DEFAULT_LDS_SIZE)
|
||||
double_buffer = pipeline in ["preshufflev2", "compv4"]
|
||||
max_tile_size = hw_lds_size // 2 if double_buffer else hw_lds_size
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB) "
|
||||
f"[{base_gpu_target}, {'double' if double_buffer else 'single'} buffer]. Breakdown:\n"
|
||||
f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
@@ -461,7 +476,7 @@ def is_tile_config_valid(
|
||||
|
||||
# Validate LDS capacity
|
||||
lds_valid, lds_error = validate_lds_capacity(
|
||||
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
|
||||
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline, gpu_target
|
||||
)
|
||||
if not lds_valid:
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
|
||||
Reference in New Issue
Block a user