[rocm-libraries] ROCm/rocm-libraries#5114 (commit 59b8cb5)

[CK][CK Tile] Improvements for grouped conv fwd tile
 profiling (#5114)

## Motivation

Improve profiling for grouped convolution forward to allow a better comparison
between CK and CK Tile.
## Technical Details

- Include preprocessing time for CK Tile
- Add cache flushing to the conv fwd profiler
- Switch configs to builder reflect
- Add KPerXdl deduction
- Add non-grouped ported instances

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-786
This commit is contained in:
Bartłomiej Kocot
2026-03-11 22:39:20 +00:00
committed by assistant-librarian[bot]
parent c1f2d8166d
commit 2169367735
24 changed files with 2375 additions and 1874 deletions

View File

@@ -108,6 +108,33 @@ def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector)
return False
return True
def parse_instance_string(instance_string):
    """Split a comma-separated parameter string into its top-level fields.

    Commas that appear inside parentheses (for example in ``Seq(1, 2, 3)``)
    do not split, so a parenthesised group is returned as a single
    parameter. Parentheses may be nested.

    Args:
        instance_string: The text to split, e.g. the contents between the
            outermost ``<`` and ``>`` of an instance description.

    Returns:
        A list of parameters with surrounding whitespace stripped. Empty
        fields between two top-level commas are kept as ``""``; an empty
        (or whitespace-only) trailing field is dropped.
    """
    params = []
    # Collect characters in a list and join once per parameter instead of
    # growing a string with repeated ``+=``.
    buffer = []
    depth = 0
    for symbol in instance_string:
        if symbol == ',' and depth == 0:
            # A comma outside any parentheses terminates the current field.
            params.append("".join(buffer).strip())
            buffer = []
            continue
        if symbol == '(':
            depth += 1
        elif symbol == ')':
            depth -= 1
        buffer.append(symbol)
    # Whatever follows the last top-level comma is the final parameter;
    # it is only recorded when it is non-empty after stripping.
    tail = "".join(buffer).strip()
    if tail:
        params.append(tail)
    return params
def generate_calls_inc(instances, problem_name, direction, filter_pattern):
generate_dir = Path(__file__).resolve().parent
output_dir = Path(f"{generate_dir}/instances/{direction}")
@@ -168,69 +195,80 @@ def parse_fwd_instances(instances, problem_name):
for instance_id, instance in enumerate(instances):
if instance.find("#") != -1 or instance.find(";") != -1:
continue
instance_args_list = instance[instance.find("<") + 1 : instance.find(">")]
args = instance_args_list.split(", ")
start = instance.index('<') + 1
end = instance.rindex('>')
params_str = instance[start:end]
args = parse_instance_string(params_str)
block_size = int(args[0])
m_per_block = int(args[1])
n_per_block = int(args[2])
k_per_block = int(args[3])
spec = args[4]
m_per_xdl = int(args[5])
n_per_xdl = int(args[6])
m_xdl_per_wave = int(args[7])
n_xdl_per_wave = int(args[8])
a_scalar_per_vector = int(args[9])
b_scalar_per_vector = int(args[10])
c_scalar_per_vector = int(args[11])
if len(args) == 15:
num_groups_to_merge = int(args[14])
elif len(args) != 16 and len(args) != 14:
raise RuntimeError("wrong number of parameters")
is_v3_instance = instance.find("Xdl_CShuffle_V3") != -1
split_image = instance.find("Large_Tensor") != -1
if is_v3_instance:
spec = args[14]
block_size = int(args[16])
m_per_block = int(args[17])
n_per_block = int(args[18])
k_per_block = int(args[19])
k1 = int(args[20])
m_per_xdl = int(args[22])
n_per_xdl = int(args[23])
m_xdl_per_wave = int(args[24])
n_xdl_per_wave = int(args[25])
a_scalar_per_vector = int(args[30])
b_scalar_per_vector = int(args[37])
c_scalar_per_vector = int(args[43])
scheduler = args[44]
pipeline_version = args[45]
direct_load = args[48] == "true"
num_groups_to_merge = int(args[49])
else:
num_groups_to_merge = 1
split_image = instance.find("Large") != -1
double_smem_buffer = instance.find("BlkGemmPipelineVersion: v4") != -1
spec = args[14]
block_size = int(args[17])
m_per_block = int(args[18])
n_per_block = int(args[19])
k_per_block = int(args[20])
k1 = int(args[21])
m_per_xdl = int(args[23])
n_per_xdl = int(args[24])
m_xdl_per_wave = int(args[25])
n_xdl_per_wave = int(args[26])
a_scalar_per_vector = int(args[31])
b_scalar_per_vector = int(args[38])
c_scalar_per_vector = int(args[44])
scheduler = "Intrawave"
pipeline_version = "v1"
direct_load = 0
num_groups_to_merge = 0 if split_image else int(args[48])
double_smem_buffer = pipeline_version == "v4"
num_wave_groups = 1
scheduler = (
"Intrawave" if instance.find("BlkGemmPipelineScheduler") == -1 else args[14]
)
pipeline_version = (
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
)
# Replace pipeline if Direct Load
if instance.find("DirectLoad") != -1:
if instance.find("BlkGemmPipelineVersion: v1") != -1:
if direct_load:
if pipeline_version == "v1":
pipeline_version = "ASYNC_V1"
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
elif pipeline_version == "v4":
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError("not supported pipeline for direct load")
raise RuntimeError(f"{pipeline_version} not supported pipeline for direct load")
else:
pipeline_version = f"""V{pipeline_version[-1:]}"""
pipeline_version = pipeline_version.upper()
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
warp_size = 64
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
# TODO: Make it more flexible
# k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()"
if dtype == "float":
if m_per_xdl == 32:
if instance.find("BlkGemmPipelineVersion") == -1:
k_per_xdl = 4
else:
# Increase for universal gemm
k_per_xdl = 8
else:
k_per_xdl = 8
else:
if m_per_xdl == 32:
k_per_xdl = 16
else:
k_per_xdl = 32
k_per_xdl = min(k_per_xdl, k_per_block)
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
if split_image:
print(f"Skipping instance {instance_id} with split_image since it's not supported yet.")
continue
if pipeline_version == "V5":
print(f"Skipping instance {instance_id} with V5 since it's not supported yet.")
continue
if pipeline_version == "ASYNC_V4":
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
continue
conv = ConvInstanceTemplateParams(
spec,
@@ -250,32 +288,6 @@ def parse_fwd_instances(instances, problem_name):
convs.append(conv)
return convs
def parse_instance_string(instance_string):
    """Split a comma-separated parameter string into its top-level fields.

    Commas that appear inside parentheses (for example in ``Seq(1, 2, 3)``)
    do not split, so a parenthesised group is returned as a single
    parameter. Parentheses may be nested.

    Args:
        instance_string: The text to split, e.g. the contents between the
            outermost ``<`` and ``>`` of an instance description.

    Returns:
        A list of parameters with surrounding whitespace stripped. Empty
        fields between two top-level commas are kept as ``""``; an empty
        (or whitespace-only) trailing field is dropped.
    """
    params = []
    # Collect characters in a list and join once per parameter instead of
    # growing a string with repeated ``+=``.
    buffer = []
    depth = 0
    for symbol in instance_string:
        if symbol == ',' and depth == 0:
            # A comma outside any parentheses terminates the current field.
            params.append("".join(buffer).strip())
            buffer = []
            continue
        if symbol == '(':
            depth += 1
        elif symbol == ')':
            depth -= 1
        buffer.append(symbol)
    # Whatever follows the last top-level comma is the final parameter;
    # it is only recorded when it is non-empty after stripping.
    tail = "".join(buffer).strip()
    if tail:
        params.append(tail)
    return params
def parse_bwd_weight_instances(instances, problem_name):
convs = []