From 9cf49cd32275ed686b4302633da4b0c8dc87d3e8 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 21 May 2026 12:05:09 +0200 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7465 (commit 81f1cf0) [CK TILE] Increase default kPerXdl for grouped convolution instances (#7465) ## Summary Increases the default `kPerXdl` used in CK Tile grouped convolution instance generation for forward, backward-data, and backward-weight operations. ### Changes in `generate_instances.py` - **Larger default `kPerXdl` for all fp16/bf16 tile sizes**: `get_k_mfma()` now returns `32` for `m/nPerXdl = 16` and `16` for `m/nPerXdl = 32`. - **Cap `kPerXdl` to `kPerBlock`**: All three parsers (`parse_fwd_instances`, `parse_bwd_weight_instances`, `parse_bwd_data_instances`) now clamp the computed value with `min(..., k_per_block)` to prevent generating invalid instances where `kPerXdl > kPerBlock`. ### Expected impact Higher `kPerXdl` increases the number of MFMA instructions issued per warp per inner-loop iteration, improving arithmetic intensity and reducing pipeline stall overhead for memory-bound shapes. --- .../generate_instances.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 5620b62228..2651b16b3a 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -118,9 +118,9 @@ def get_k_mfma(dtype, m_per_xdl, n_per_xdl): return 4 else: if m_per_xdl == 32: - return 8 - else: return 16 + else: + return 32 def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector): @@ -444,7 +444,7 @@ def parse_fwd_instances(instances, problem_name): warp_size = 64 k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if split_image: print( @@ -656,7 +656,7 @@ def parse_bwd_weight_instances(instances, problem_name): k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if not check_vectors( a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector @@ -792,7 +792,7 @@ def parse_bwd_data_instances(instances, problem_name): k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False: print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")