[rocm-libraries] ROCm/rocm-libraries#7465 (commit 81f1cf0)

[CK TILE] Increase default kPerXdl for grouped convolution instances  (#7465)

## Summary
Increases the default `kPerXdl` used in CK Tile grouped convolution
instance generation for forward, backward-data, and backward-weight
operations.
### Changes in `generate_instances.py`
- **Larger default `kPerXdl` for all fp16/bf16 tile sizes**:
`get_k_mfma()` now
returns `32` for `m/nPerXdl = 16` and `16` for `m/nPerXdl = 32`.
- **Cap `kPerXdl` to `kPerBlock`**: All three parsers
(`parse_fwd_instances`, `parse_bwd_weight_instances`,
`parse_bwd_data_instances`) now clamp the computed value with `min(...,
k_per_block)` to prevent generating invalid instances where `kPerXdl >
kPerBlock`.
### Expected impact
Higher `kPerXdl` increases the number of MFMA instructions issued per
warp per inner-loop iteration, improving arithmetic intensity and
reducing pipeline stall overhead for memory-bound shapes.
This commit is contained in:
jakpiase
2026-05-21 12:05:09 +02:00
committed by GitHub
parent e7798e9560
commit 9cf49cd322

View File

@@ -118,9 +118,9 @@ def get_k_mfma(dtype, m_per_xdl, n_per_xdl):
return 4
else:
if m_per_xdl == 32:
return 8
else:
return 16
else:
return 32
def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
@@ -444,7 +444,7 @@ def parse_fwd_instances(instances, problem_name):
warp_size = 64
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if split_image:
print(
@@ -656,7 +656,7 @@ def parse_bwd_weight_instances(instances, problem_name):
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if not check_vectors(
a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector
@@ -792,7 +792,7 @@ def parse_bwd_data_instances(instances, problem_name):
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")