mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#5114 (commit 59b8cb5)
[CK][CK Tile] Improvements for grouped conv fwd tile profiling (#5114)

## Motivation

Improve profiling of grouped convolution forward to allow a better comparison between CK and CK Tile.

## Technical Details

- Include preprocessing time for CK Tile
- Add cache flushing to the conv fwd profiler
- Switch configs to builder reflect
- Add KPerXdl deduction
- Add non-grouped ported instances

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

pass

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-786
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1f2d8166d
commit
2169367735
@@ -108,6 +108,33 @@ def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector)
|
||||
return False
|
||||
return True
|
||||
|
||||
def parse_instance_string(instance_string):
    """Split a comma-separated parameter list into its top-level fields.

    Commas nested inside parentheses (for example ``Seq(4, 64, 1)``) do not
    split, so every parenthesised group stays inside a single parameter.

    Args:
        instance_string: The parameter text to split, e.g. the text between
            the outermost ``<`` and ``>`` of an instance description.

    Returns:
        The stripped parameter strings in order of appearance. An empty field
        between two commas is kept as ``""``; a blank trailing field (such as
        the one after a final comma) is dropped.
    """
    params = []
    field_start = 0
    paren_depth = 0

    for pos, char in enumerate(instance_string):
        if char == "(":
            paren_depth += 1
        elif char == ")":
            # Clamp at zero: a stray ')' must not push the depth negative,
            # otherwise no later top-level comma would ever split again.
            paren_depth = max(paren_depth - 1, 0)
        elif char == "," and paren_depth == 0:
            # Only split on a comma that is outside every parenthesis.
            params.append(instance_string[field_start:pos].strip())
            field_start = pos + 1

    # Add the last parameter unless it is blank.
    last_param = instance_string[field_start:].strip()
    if last_param:
        params.append(last_param)

    return params
|
||||
|
||||
|
||||
def generate_calls_inc(instances, problem_name, direction, filter_pattern):
|
||||
generate_dir = Path(__file__).resolve().parent
|
||||
output_dir = Path(f"{generate_dir}/instances/{direction}")
|
||||
@@ -168,69 +195,80 @@ def parse_fwd_instances(instances, problem_name):
|
||||
for instance_id, instance in enumerate(instances):
|
||||
if instance.find("#") != -1 or instance.find(";") != -1:
|
||||
continue
|
||||
instance_args_list = instance[instance.find("<") + 1 : instance.find(">")]
|
||||
args = instance_args_list.split(", ")
|
||||
start = instance.index('<') + 1
|
||||
end = instance.rindex('>')
|
||||
params_str = instance[start:end]
|
||||
args = parse_instance_string(params_str)
|
||||
|
||||
block_size = int(args[0])
|
||||
m_per_block = int(args[1])
|
||||
n_per_block = int(args[2])
|
||||
k_per_block = int(args[3])
|
||||
spec = args[4]
|
||||
m_per_xdl = int(args[5])
|
||||
n_per_xdl = int(args[6])
|
||||
m_xdl_per_wave = int(args[7])
|
||||
n_xdl_per_wave = int(args[8])
|
||||
a_scalar_per_vector = int(args[9])
|
||||
b_scalar_per_vector = int(args[10])
|
||||
c_scalar_per_vector = int(args[11])
|
||||
if len(args) == 15:
|
||||
num_groups_to_merge = int(args[14])
|
||||
elif len(args) != 16 and len(args) != 14:
|
||||
raise RuntimeError("wrong number of parameters")
|
||||
is_v3_instance = instance.find("Xdl_CShuffle_V3") != -1
|
||||
split_image = instance.find("Large_Tensor") != -1
|
||||
|
||||
if is_v3_instance:
|
||||
spec = args[14]
|
||||
block_size = int(args[16])
|
||||
m_per_block = int(args[17])
|
||||
n_per_block = int(args[18])
|
||||
k_per_block = int(args[19])
|
||||
k1 = int(args[20])
|
||||
m_per_xdl = int(args[22])
|
||||
n_per_xdl = int(args[23])
|
||||
m_xdl_per_wave = int(args[24])
|
||||
n_xdl_per_wave = int(args[25])
|
||||
a_scalar_per_vector = int(args[30])
|
||||
b_scalar_per_vector = int(args[37])
|
||||
c_scalar_per_vector = int(args[43])
|
||||
scheduler = args[44]
|
||||
pipeline_version = args[45]
|
||||
direct_load = args[48] == "true"
|
||||
num_groups_to_merge = int(args[49])
|
||||
else:
|
||||
num_groups_to_merge = 1
|
||||
split_image = instance.find("Large") != -1
|
||||
double_smem_buffer = instance.find("BlkGemmPipelineVersion: v4") != -1
|
||||
spec = args[14]
|
||||
block_size = int(args[17])
|
||||
m_per_block = int(args[18])
|
||||
n_per_block = int(args[19])
|
||||
k_per_block = int(args[20])
|
||||
k1 = int(args[21])
|
||||
m_per_xdl = int(args[23])
|
||||
n_per_xdl = int(args[24])
|
||||
m_xdl_per_wave = int(args[25])
|
||||
n_xdl_per_wave = int(args[26])
|
||||
a_scalar_per_vector = int(args[31])
|
||||
b_scalar_per_vector = int(args[38])
|
||||
c_scalar_per_vector = int(args[44])
|
||||
scheduler = "Intrawave"
|
||||
pipeline_version = "v1"
|
||||
direct_load = 0
|
||||
num_groups_to_merge = 0 if split_image else int(args[48])
|
||||
|
||||
double_smem_buffer = pipeline_version == "v4"
|
||||
num_wave_groups = 1
|
||||
scheduler = (
|
||||
"Intrawave" if instance.find("BlkGemmPipelineScheduler") == -1 else args[14]
|
||||
)
|
||||
pipeline_version = (
|
||||
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
|
||||
)
|
||||
# Replace pipeline if Direct Load
|
||||
if instance.find("DirectLoad") != -1:
|
||||
if instance.find("BlkGemmPipelineVersion: v1") != -1:
|
||||
if direct_load:
|
||||
if pipeline_version == "v1":
|
||||
pipeline_version = "ASYNC_V1"
|
||||
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
|
||||
elif pipeline_version == "v4":
|
||||
pipeline_version = "ASYNC_V4"
|
||||
else:
|
||||
raise RuntimeError("not supported pipeline for direct load")
|
||||
raise RuntimeError(f"{pipeline_version} not supported pipeline for direct load")
|
||||
else:
|
||||
pipeline_version = f"""V{pipeline_version[-1:]}"""
|
||||
pipeline_version = pipeline_version.upper()
|
||||
|
||||
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
|
||||
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
|
||||
warp_size = 64
|
||||
k_warp = int(block_size / (warp_size * m_warp * n_warp))
|
||||
dtype = get_dtype(problem_name)
|
||||
# TODO: Make it more flexible
|
||||
# k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()"
|
||||
if dtype == "float":
|
||||
if m_per_xdl == 32:
|
||||
if instance.find("BlkGemmPipelineVersion") == -1:
|
||||
k_per_xdl = 4
|
||||
else:
|
||||
# Increase for universal gemm
|
||||
k_per_xdl = 8
|
||||
else:
|
||||
k_per_xdl = 8
|
||||
else:
|
||||
if m_per_xdl == 32:
|
||||
k_per_xdl = 16
|
||||
else:
|
||||
k_per_xdl = 32
|
||||
k_per_xdl = min(k_per_xdl, k_per_block)
|
||||
k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))
|
||||
|
||||
if split_image:
|
||||
print(f"Skipping instance {instance_id} with split_image since it's not supported yet.")
|
||||
continue
|
||||
if pipeline_version == "V5":
|
||||
print(f"Skipping instance {instance_id} with V5 since it's not supported yet.")
|
||||
continue
|
||||
if pipeline_version == "ASYNC_V4":
|
||||
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
@@ -250,32 +288,6 @@ def parse_fwd_instances(instances, problem_name):
|
||||
convs.append(conv)
|
||||
return convs
|
||||
|
||||
def parse_instance_string(instance_string):
    """Split a comma-separated parameter list into its top-level fields.

    Commas nested inside parentheses (for example ``Seq(4, 64, 1)``) do not
    split, so every parenthesised group stays inside a single parameter.

    Args:
        instance_string: The parameter text to split, e.g. the text between
            the outermost ``<`` and ``>`` of an instance description.

    Returns:
        The stripped parameter strings in order of appearance. An empty field
        between two commas is kept as ``""``; a blank trailing field (such as
        the one after a final comma) is dropped.
    """
    params = []
    field_start = 0
    paren_depth = 0

    for pos, char in enumerate(instance_string):
        if char == "(":
            paren_depth += 1
        elif char == ")":
            # Clamp at zero: a stray ')' must not push the depth negative,
            # otherwise no later top-level comma would ever split again.
            paren_depth = max(paren_depth - 1, 0)
        elif char == "," and paren_depth == 0:
            # Only split on a comma that is outside every parenthesis.
            params.append(instance_string[field_start:pos].strip())
            field_start = pos + 1

    # Add the last parameter unless it is blank.
    last_param = instance_string[field_start:].strip()
    if last_param:
        params.append(last_param)

    return params
|
||||
|
||||
def parse_bwd_weight_instances(instances, problem_name):
|
||||
convs = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user