From 7d7d293e5d5b84c6d4fe72c1e9cc2b421fbcfbdb Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Tue, 26 May 2026 16:43:05 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6511 (commit 867bece) [CK_TILE] Adding steps in Stream-K Tile Engine ## Motivation This PR adds step functionality to the Stream-K instance generator in Tile Engine in order to quickly generate instance configurations within a certain max/min range. To complement this, the Stream-K Tile Engine validation file has been updated for more rigorous validation of generated instances. ## Technical Details - Added _generate_values helper to support min/max/step range-based tile config generation, matching Universal GEMM - Added validate_gemm, validate_whole_wg_cover_configuration, validate_cshuffle_epilogue_distribution, and other supporting functions to the Stream-K validation utils, aligning with the validation already present in the Universal GEMM ## Test Plan Tested using the generation in CK Tile Engine ## Test Result All instances were generated and validated correctly. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../gemm_streamk_instance_builder.py | 31 +- .../gemm_streamk_validation_utils.py | 327 +++++++++++++++++- 2 files changed, 351 insertions(+), 7 deletions(-) diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 8fd422e6b8..1b3b9a7f26 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -127,6 +127,25 @@ class GemmKernelBuilder: # New format - generate combinations from individual parameter values tile_config = self.config["tile_config"] + if tile_config.get("tile_m").get("values") is None: + tile_config.get("tile_m")["values"] = self._generate_values( + tile_config.get("tile_m").get("min"), + tile_config.get("tile_m").get("max"), + tile_config.get("tile_m").get("step"), + ) + if tile_config.get("tile_n").get("values") is None: + tile_config.get("tile_n")["values"] = self._generate_values( + tile_config.get("tile_n").get("min"), + tile_config.get("tile_n").get("max"), + tile_config.get("tile_n").get("step"), + ) + if tile_config.get("tile_k").get("values") is None: + tile_config.get("tile_k")["values"] = self._generate_values( + tile_config.get("tile_k").get("min"), + tile_config.get("tile_k").get("max"), + tile_config.get("tile_k").get("step"), + ) + # Get all possible values for each parameter tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) @@ -180,6 +199,15 @@ class GemmKernelBuilder: # Fallback to default return [] + def _generate_values(self, min_val, max_val, step): + """Generate a list of values from min to max with the given step""" + values = [] + val = min_val + while val <= max_val: + values.append(val) + val += step + return values + def _validate_tile_config( self, tile_m, @@ -191,7 +219,7 @@ class GemmKernelBuilder: warp_tile_m, warp_tile_n, warp_tile_k, - pipeline="mem", # Default pipeline for validation + pipeline="compv3", # Default pipeline for validation fast_mode=False, # Add fast mode option ): """Validate that tile configuration is reasonable""" @@ -239,6 +267,7 @@ class GemmKernelBuilder: b_datatype, c_datatype, pipeline, + self.layout ) def _generate_trait_combinations(self): diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index d6c76c95b5..eccf5be108 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -26,6 +26,16 @@ ELEMENT_SIZE_MAP = { "fp64": 8, } +def get_warp_size_for_gpu(gpu_target: str) -> int: + """Get the warp size for a given GPU target. + + CDNA architectures (gfx9xx) use WAVE64 (64 threads per wavefront). + RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use WAVE32 (32 threads per wavefront). + """ + if gpu_target.startswith("gfx9"): + return 64 # CDNA - WAVE64 + return 32 # RDNA and others - WAVE32 + # Supported warp tile combinations for different GPU architectures and data types WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { @@ -34,7 +44,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -42,7 +51,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], @@ -54,7 +62,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -62,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], @@ -75,7 +81,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -83,7 +88,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [ @@ -333,6 +337,7 @@ def is_tile_config_valid( b_datatype: str, c_datatype: str, pipeline: str, + layout: str, trait_name: str = None, ) -> bool: """ @@ -390,6 +395,17 @@ def is_tile_config_valid( logging.debug(f"LDS validation failed: {lds_error}") return False + gemm_valid, gemm_error = validate_gemm( + tile_m, tile_n, tile_k, + warp_m, warp_n, warp_k, + warp_tile_m, warp_tile_n, warp_tile_k, + a_datatype, b_datatype, c_datatype, + pipeline, layout, + ) + if not gemm_valid: + logging.debug(f"GEMM validation failed: {gemm_error}") + return False + # Validate warp tile combination warp_tile_valid, warp_tile_error = validate_warp_tile_combination( warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype @@ -399,3 +415,302 @@ def is_tile_config_valid( return False return True + +def validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + gpu_target: str = "gfx90a", +) -> Tuple[bool, str]: + # Validate whole workgroup cover configuration + + warp_size = get_warp_size_for_gpu(gpu_target) + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + XPerTile = 0 + YPerTile = 0 + vector_load_size = 0 + + # A matrix validation + if layout[0] == "r": + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, a_datatype, tile_m, tile_k + ) + + XPerTile = tile_k + YPerTile = tile_m + + elif layout[0] == "c": + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, a_datatype, tile_m, tile_m + ) + + # Validate distribution + XPerTile = tile_k + YPerTile = tile_m + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not wg_cover_core_valid: + logging.debug( + f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + XPerTile = tile_m + YPerTile = tile_k + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not wg_cover_core_valid: + logging.debug( + f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + # B matrix validation + if layout[1] == "r": + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, b_datatype, tile_n, tile_n + ) + + # Validate distribution + XPerTile = tile_k + YPerTile = tile_n + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not wg_cover_core_valid: + logging.debug( + f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + XPerTile = tile_n + YPerTile = tile_k + + elif layout[1] == "c": + XPerTile = tile_k + YPerTile = tile_n + + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, b_datatype, tile_n, tile_k + ) + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + if not wg_cover_core_valid: + logging.debug( + f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + return True, "" + +def wg_cover_core_validation( + XPerTile: int, + YPerTile: int, + BlockSize: int, + vector_load_size: int, + warp_size: int, +) -> Tuple[bool, str]: + if XPerTile % vector_load_size != 0: + return False, "XPerTile is not divisible by vector_load_size" + + num_warps = BlockSize / warp_size + LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size) + + X1 = LargestVec if vector_load_size > LargestVec else vector_load_size + X0 = XPerTile / X1 + Y1 = warp_size // X0 + + if X0 * Y1 != warp_size: + return False, "X0 * Y1 != warp_size" + + return True, "" + + +def validate_cshuffle_epilogue_distribution( + tile_m: int, + tile_n: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_size: int, + c_datatype: str, +) -> Tuple[bool, str]: + """ + Validate that the CShuffleEpilogue tile distribution pattern is valid. + + This mirrors the static_assert in static_encoding_pattern.hpp: + static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!"); + + The CShuffleEpilogue creates a tile_distribution_encoding_pattern_2d + where: + - BlockSize = warp_m * warp_n * warp_k * warp_size + - YPerTile = MPerIterationShuffle (derived from tile_m / (warp_m * warp_tile_m / some_factor)) + - XPerTile = NPerIterationShuffle (derived from tile_n) + - VecSize = vector size based on element size (typically 8 for fp16) + + The key constraint is that X0 must evenly divide warp_size, where: + - X0 = min(warp_size, XPerTile / X1) + - X1 = min(VecSize, LargestVec) + - LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size) + """ + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + elem_size = ELEMENT_SIZE_MAP.get(c_datatype, 2) + VecSize = 16 // elem_size + + XPerTile = tile_n + YPerTile = tile_m // warp_m + + if XPerTile <= 0 or YPerTile <= 0: + return False, f"Invalid tile dimensions: XPerTile={XPerTile}, YPerTile={YPerTile}" + + num_warps = BlockSize // warp_size + if num_warps * warp_size == 0: + return False, "Invalid BlockSize or warp_size" + + LargestVec = (XPerTile * YPerTile) // (num_warps * warp_size) + if LargestVec <= 0: + LargestVec = 1 + + X1 = min(VecSize, LargestVec) if LargestVec > 0 else VecSize + if X1 <= 0: + X1 = 1 + + X0 = min(warp_size, XPerTile // X1) if X1 > 0 else warp_size + + Y1 = warp_size // X0 if X0 > 0 else 0 + + if X0 * Y1 != warp_size: + return ( + False, + f"CShuffleEpilogue distribution invalid: X0({X0}) * Y1({Y1}) = {X0 * Y1} != warp_size({warp_size}). " + f"XPerTile={XPerTile}, YPerTile={YPerTile}, VecSize={VecSize}, BlockSize={BlockSize}" + ) + + return True, "" + + +def get_global_vector_load_size( + BlockSize: int, + KPerBlock: int, + DataType: str, + MNPerBlock: int, + XPerTile: int, +) -> int: + elements_per_thread = MNPerBlock * KPerBlock / BlockSize + PackedSize = 1 + + if ( + PackedSize == 2 + and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0 + ): + return PackedSize * 32 / element_size(DataType) + elif ( + XPerTile % (PackedSize * 16 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0 + ): + return int(PackedSize * 16 / element_size(DataType)) + + elif ( + XPerTile % (PackedSize * 8 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0 + ): + return int(PackedSize * 8 / element_size(DataType)) + elif ( + element_size(DataType) >= PackedSize * 4 + and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0 + ): + return int(PackedSize * 4 / element_size(DataType)) + elif ( + element_size(DataType) >= PackedSize * 2 + and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0 + ): + return int(PackedSize * 2 / element_size(DataType)) + else: + return PackedSize + +def validate_gemm( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + trait_name: str = None, +) -> bool: + # GEMM Validation + gpu_target = get_gpu_name_by_id(0) + warp_size = get_warp_size_for_gpu(gpu_target) + + # Validate whole workgroup cover configuration + whole_workgroup_cover_valid, whole_workgroup_cover_error = ( + validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + gpu_target, + ) + ) + if not whole_workgroup_cover_valid: + logging.debug( + f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}" + ) + return False, whole_workgroup_cover_error + + # Validate CShuffleEpilogue distribution pattern (for cshuffle epilogue) + # This validation ensures the tile distribution pattern is valid for the output tile + cshuffle_valid, cshuffle_error = validate_cshuffle_epilogue_distribution( + tile_m, + tile_n, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_size, + c_datatype, + ) + if not cshuffle_valid: + logging.debug(f"CShuffleEpilogue validation failed: {cshuffle_error}") + return False, cshuffle_error + + return True, "" +