diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py
index 9eb396eea3..792c340ad7 100644
--- a/tile_engine/ops/gemm/gemm_instance_builder.py
+++ b/tile_engine/ops/gemm/gemm_instance_builder.py
@@ -154,6 +154,21 @@ class GemmKernelBuilder:
                     persistent,
                 ) = trait_combo
 
+                # Skip if this tile config is not valid for this specific pipeline
+                if not self._validate_tile_config(
+                    tile_config["tile_m"],
+                    tile_config["tile_n"],
+                    tile_config["tile_k"],
+                    tile_config["warp_m"],
+                    tile_config["warp_n"],
+                    tile_config["warp_k"],
+                    tile_config["warp_tile_m"],
+                    tile_config["warp_tile_n"],
+                    tile_config["warp_tile_k"],
+                    pipeline,
+                ):
+                    continue
+
                 # Create kernel name with proper boolean capitalization
                 kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
 
@@ -240,15 +255,13 @@ class GemmKernelBuilder:
         warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
 
         # Generate all combinations
-        default_pipeline = ""
-        if self.kernel_name_prefix == "gemm_universal":
-            default_pipeline = "compv4"
-        elif self.kernel_name_prefix == "gemm_multi_d":
-            default_pipeline = "compv4"
-        elif self.kernel_name_prefix == "gemm_preshuffle":
-            default_pipeline = "preshufflev2"
-        elif self.kernel_name_prefix == "grouped_gemm":
-            default_pipeline = "compv4"
+        # Validate against the configured pipelines; fall back to a per-kernel default
+        pipelines = self.config["trait_config"].get("pipeline", {}).get("values", [])
+        if not pipelines:
+            if self.kernel_name_prefix == "gemm_preshuffle":
+                pipelines = ["preshufflev2"]
+            else:
+                pipelines = ["compv4"]
 
         configs = []
         for tile_m in tile_m_values:
@@ -260,18 +273,21 @@ class GemmKernelBuilder:
                                 for warp_tile_m in warp_tile_m_values:
                                     for warp_tile_n in warp_tile_n_values:
                                         for warp_tile_k in warp_tile_k_values:
-                                            # Validate configuration
-                                            if self._validate_tile_config(
-                                                tile_m,
-                                                tile_n,
-                                                tile_k,
-                                                warp_m,
-                                                warp_n,
-                                                warp_k,
-                                                warp_tile_m,
-                                                warp_tile_n,
-                                                warp_tile_k,
-                                                default_pipeline,
+                                            # Accept tile if valid for any pipeline
+                                            if any(
+                                                self._validate_tile_config(
+                                                    tile_m,
+                                                    tile_n,
+                                                    tile_k,
+                                                    warp_m,
+                                                    warp_n,
+                                                    warp_k,
+                                                    warp_tile_m,
+                                                    warp_tile_n,
+                                                    warp_tile_k,
+                                                    pipeline,
+                                                )
+                                                for pipeline in pipelines
                                             ):
                                                 configs.append(
                                                     {
diff --git a/tile_engine/ops/gemm/gemm_validation_utils.py b/tile_engine/ops/gemm/gemm_validation_utils.py
index aa3c04cf95..deb39845b7 100644
--- a/tile_engine/ops/gemm/gemm_validation_utils.py
+++ b/tile_engine/ops/gemm/gemm_validation_utils.py
@@ -282,6 +282,18 @@ def validate_dimension_alignment(
     return len(alignment_issues) == 0, alignment_issues
 
 
+# LDS capacity in bytes, keyed by base GPU architecture name
+LDS_SIZE_MAP = {
+    "gfx90a": 2**16,  # 64KB
+    "gfx942": 2**16,  # 64KB
+    "gfx950": 160 * 1024,  # 160KB
+    "gfx1201": 2**16,  # 64KB
+}
+
+# Capacity used when the GPU target is not listed in LDS_SIZE_MAP
+DEFAULT_LDS_SIZE = 2**16  # 64KB
+
+
 def validate_lds_capacity(
     tile_m: int,
     tile_n: int,
@@ -289,18 +301,25 @@ def validate_lds_capacity(
     a_datatype: str,
     b_datatype: str,
     pipeline: str,
+    gpu_target: str = "",
 ) -> Tuple[bool, str]:
     """Validate LDS capacity requirements."""
     matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
     matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
     total_tile_in_lds = matrix_a_size + matrix_b_size
 
-    max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
+    # Look up the capacity by the part of the target before the first ":"
+    base_gpu_target = gpu_target.split(":")[0] if gpu_target else gpu_target
+    hw_lds_size = LDS_SIZE_MAP.get(base_gpu_target, DEFAULT_LDS_SIZE)
+    # Pipelines flagged as double-buffer get half of the LDS as their tile budget
+    double_buffer = pipeline in ["preshufflev2", "compv4"]
+    max_tile_size = hw_lds_size // 2 if double_buffer else hw_lds_size
 
     if total_tile_in_lds > max_tile_size:
         error_msg = (
             f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
-            f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
+            f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB) "
+            f"[{base_gpu_target or 'unspecified target'}, {'double' if double_buffer else 'single'} buffer]. Breakdown:\n"
             f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
             f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
         )
@@ -461,7 +480,7 @@ def is_tile_config_valid(
 
     # Validate LDS capacity
     lds_valid, lds_error = validate_lds_capacity(
-        tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
+        tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline, gpu_target
     )
     if not lds_valid:
         logging.debug(f"LDS validation failed: {lds_error}")