mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#7465 (commit 81f1cf0)
[CK TILE] Increase default kPerXdl for grouped convolution instances (#7465) ## Summary Increases the default `kPerXdl` used in CK Tile grouped convolution instance generation for forward, backward-data, and backward-weight operations. ### Changes in `generate_instances.py` - **Larger default `kPerXdl` for all fp16/bf16 tile sizes**: `get_k_mfma()` now returns `32` for `m/nPerXdl = 16` and `16` for `m/nPerXdl = 32`. - **Cap `kPerXdl` to `kPerBlock`**: All three parsers (`parse_fwd_instances`, `parse_bwd_weight_instances`, `parse_bwd_data_instances`) now clamp the computed value with `min(..., k_per_block)` to prevent generating invalid instances where `kPerXdl > kPerBlock`. ### Expected impact Higher `kPerXdl` increases the number of MFMA instructions issued per warp per inner-loop iteration, improving arithmetic intensity and reducing pipeline stall overhead for memory-bound shapes.
This commit is contained in:
@@ -118,9 +118,9 @@ def get_k_mfma(dtype, m_per_xdl, n_per_xdl):
|
||||
return 4
|
||||
else:
|
||||
if m_per_xdl == 32:
|
||||
return 8
|
||||
else:
|
||||
return 16
|
||||
else:
|
||||
return 32
|
||||
|
||||
|
||||
def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
|
||||
@@ -444,7 +444,7 @@ def parse_fwd_instances(instances, problem_name):
|
||||
warp_size = 64
|
||||
k_warp = int(block_size / (warp_size * m_warp * n_warp))
|
||||
dtype = get_dtype(problem_name)
|
||||
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
|
||||
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
|
||||
|
||||
if split_image:
|
||||
print(
|
||||
@@ -656,7 +656,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
k_warp = int(block_size / (warp_size * m_warp * n_warp))
|
||||
dtype = get_dtype(problem_name)
|
||||
|
||||
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
|
||||
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
|
||||
|
||||
if not check_vectors(
|
||||
a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector
|
||||
@@ -792,7 +792,7 @@ def parse_bwd_data_instances(instances, problem_name):
|
||||
k_warp = int(block_size / (warp_size * m_warp * n_warp))
|
||||
dtype = get_dtype(problem_name)
|
||||
|
||||
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
|
||||
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
|
||||
|
||||
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
|
||||
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
|
||||
|
||||
Reference in New Issue
Block a user