Files
jakpiase 9cf49cd322 [rocm-libraries] ROCm/rocm-libraries#7465 (commit 81f1cf0)
[CK TILE] Increase default kPerXdl for grouped convolution instances  (#7465)

## Summary
Increases the default `kPerXdl` used in CK Tile grouped convolution
instance generation for forward, backward-data, and backward-weight
operations.
### Changes in `generate_instances.py`
- **Larger default `kPerXdl` for all fp16/bf16 tile sizes**:
`get_k_mfma()` now
returns `32` for `m/nPerXdl = 16` and `16` for `m/nPerXdl = 32`.
- **Cap `kPerXdl` to `kPerBlock`**: All three parsers
(`parse_fwd_instances`, `parse_bwd_weight_instances`,
`parse_bwd_data_instances`) now clamp the computed value with `min(...,
k_per_block)` to prevent generating invalid instances where `kPerXdl >
kPerBlock`.
### Expected impact
Higher `kPerXdl` increases the number of MFMA instructions issued per
warp per inner-loop iteration, improving arithmetic intensity and
reducing pipeline stall overhead for memory-bound shapes.
2026-05-21 12:05:09 +02:00

1167 lines
43 KiB
Python
Executable File

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import argparse
import shutil
from pathlib import Path
class ConvInstanceTemplateParams:
def __init__(
self,
specialization,
tile_size,
warps,
warp_tile,
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
scalar_per_vector,
num_groups_to_merge,
split_image,
explicit_gemm,
id,
streamk_enabled=False,
streamk_reduction_strategy=None,
streamk_persistent=False,
):
self.specialization = specialization
self.tile_size = tile_size
self.warps = warps
self.warp_tile = warp_tile
self.double_smem_buffer = double_smem_buffer
self.num_wave_groups = num_wave_groups
self.is_two_stage_instance = is_two_stage_instance
self.pipeline_version = pipeline_version
self.scheduler = scheduler
self.scalar_per_vector = scalar_per_vector
self.num_groups_to_merge = num_groups_to_merge
self.split_image = split_image
self.explicit_gemm = explicit_gemm
self.id = id
self.streamk_enabled = streamk_enabled
self.streamk_reduction_strategy = streamk_reduction_strategy
self.streamk_persistent = streamk_persistent
def get_optimizations(self):
explicit_gemm = "true" if self.explicit_gemm else "false"
split_image = "true" if self.split_image else "false"
num_groups_to_merge = str(self.num_groups_to_merge)
two_stage_instance = "true" if self.is_two_stage_instance else "false"
if self.streamk_enabled:
streamk_str = (
f"{{true, ckb::StreamKReductionStrategy::{self.streamk_reduction_strategy}, "
f"{'true' if self.streamk_persistent else 'false'}}}"
)
else:
streamk_str = "ckb::StreamKConfig::disabled()"
return (
f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, "
f".split_image = {split_image}, .explicit_gemm = {explicit_gemm}, "
f".two_stage = {two_stage_instance}, .streamk = {streamk_str}}}"
)
def get_specialization(self):
namespace = "ckb::TileConvSpecialization::"
if self.specialization == "Default" or self.specialization == "OddC":
return namespace + "DEFAULT"
if self.specialization == "Filter1x1Pad0":
return namespace + "FILTER_1X1_PAD0"
if self.specialization == "Filter1x1Stride1Pad0":
return namespace + "FILTER_1X1_STRIDE1_PAD0"
if self.specialization == "Filter3x3":
return namespace + "FILTER_3x3"
else:
raise RuntimeError("not supported specialization")
def get_thread_block(self):
return f"ckt::TileThreadBlock{{.tile_size = {{.m = {self.tile_size[0]}, .n = {self.tile_size[1]}, .k = {self.tile_size[2]}}}}}"
def get_block_gemm_desc(self):
double_smem_buffer = "true" if self.double_smem_buffer else "false"
scheduler = (
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
)
return f"""ckt::TileBlockGemm{{
.warps = {{.m = {self.warps[0]}, .n = {self.warps[1]}, .k = {self.warps[2]}}},
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
.double_smem_buffer = {double_smem_buffer},
.num_wave_groups = {self.num_wave_groups},
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
def get_block_transfer(self):
return f"""ckt::TileTransfer{{.a_scalar_per_vector = {self.scalar_per_vector[0]},
.b_scalar_per_vector = {self.scalar_per_vector[1]}, .c_scalar_per_vector = {self.scalar_per_vector[2]}}}"""
def get_dtype(problem_name):
if problem_name.find("fp32") != -1:
return "float"
if problem_name.find("fp16") != -1:
return "ck_tile::half_t"
if problem_name.find("bf16") != -1:
return "ck_tile::bf16_t"
else:
raise RuntimeError("Cannot parse data type from problem name: " + problem_name)
def get_k_mfma(dtype, m_per_xdl, n_per_xdl):
if m_per_xdl != n_per_xdl:
raise RuntimeError("Not supported")
if dtype == "float":
if m_per_xdl == 32:
return 2
else:
return 4
else:
if m_per_xdl == 32:
return 16
else:
return 32
def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
if a_scalar_per_vector != 1 and a_scalar_per_vector % 2 != 0:
return False
if b_scalar_per_vector != 1 and b_scalar_per_vector % 2 != 0:
return False
if c_scalar_per_vector != 1 and c_scalar_per_vector % 2 != 0:
return False
return True
def parse_instance_string(instance_string):
"""Parse instance string, treating Seq(...) as a single parameter."""
params = []
current_param = ""
paren_depth = 0
for char in instance_string:
if char == "(":
paren_depth += 1
current_param += char
elif char == ")":
paren_depth -= 1
current_param += char
elif char == "," and paren_depth == 0:
# Only split on comma if we're not inside parentheses
params.append(current_param.strip())
current_param = ""
else:
current_param += char
# Add the last parameter
if current_param.strip():
params.append(current_param.strip())
return params
def copy_includes(instances_path):
inc_dir = Path(__file__).resolve().parent
output_dir = Path(instances_path)
output_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(f"{inc_dir}/include/instance_includes.inc", instances_path)
shutil.copy(f"{inc_dir}/include/instance_run.inc", instances_path)
shutil.copy(f"{inc_dir}/include/signatures.hpp", instances_path)
def generate_calls_inc(instances, problem_name, direction, filter_pattern):
generate_dir = Path(__file__).resolve().parent
output_dir = Path(f"{generate_dir}/instances/{direction}")
output_dir.mkdir(parents=True, exist_ok=True)
with open(
f"{generate_dir}/instances/{direction}/{problem_name}_calls.inc", "w"
) as f:
if problem_name.find(filter_pattern) == -1:
return
for instance in instances:
instance_name = problem_name + "_" + str(instance.id)
f.write(f"run_alg(run_{instance_name});\n")
def generate_defs_inc(instances, problem_name, signature, direction, filter_pattern):
generate_dir = Path(__file__).resolve().parent
with open(f"{generate_dir}/instances/{direction}/{problem_name}.inc", "w") as f:
if problem_name.find(filter_pattern) == -1:
return
for instance in instances:
instance_name = problem_name + "_" + str(instance.id)
f.write(
f"std::tuple<bool, float, std::string> run_{instance_name}(\n"
f" const ckt::Args<{signature}>& args,\n"
f" const ckt::Inputs<{signature}>& inputs,\n"
f" const ckt::Outputs<{signature}>& outputs,\n"
f" const ck_tile::stream_config& s_conf);\n"
)
def generate_conv_cpp(
instances,
problem_name,
config,
direction,
signature_name,
filter_pattern,
instances_path,
):
for instance in instances:
if problem_name.find(filter_pattern) == -1:
break
instance_name = problem_name + "_" + str(instance.id)
directory_path = Path(f"{instances_path}/{direction}/{config}")
directory_path.mkdir(parents=True, exist_ok=True)
parent_dir = Path(__file__).resolve().parent
template_file = "include/grouped_convolution_tile.cpp.in"
with open(
f"{parent_dir}/{template_file}",
"r",
) as f:
content = f.read()
content = content.replace("gen_signature", signature_name)
content = content.replace("gen_instance_name", instance_name)
content = content.replace(
"gen_specialization", instance.get_specialization()
)
content = content.replace("gen_thread_block", instance.get_thread_block())
content = content.replace(
"gen_block_gemm_desc", instance.get_block_gemm_desc()
)
content = content.replace(
"gen_block_transfer", instance.get_block_transfer()
)
content = content.replace("gen_optimizations", instance.get_optimizations())
with open(
f"{instances_path}/{direction}/{config}/{instance_name}.cpp",
"w",
) as f:
f.write(content)
# Maps ck_tile pipeline names (from GetPipelineName()) to builder PipelineVersion enum names.
PIPELINE_NAME_TO_VERSION = {
"BASIC_V1": "V1",
"MEMORY": "V2",
"COMPUTE_V3": "V3",
"COMPUTE_V4": "V4",
"COMPUTE_V5": "V5",
"COMPUTE_V6": "V6",
"BASIC_ASYNC_V1": "ASYNC_V1",
"COMPUTE_ASYNC": "ASYNC_V4",
}
# Maps ck_tile StreamKReductionStrategy int values (from static_cast<int> in instance string)
# to builder enum names. ck_tile enum: Atomic=0, Linear=1, Tree=2.
# Atomic=0 is omitted: it is not expected in generated instances. If encountered, .get()
# falls back to str(reduction_int) ("0"), which will cause a downstream build error.
STREAMK_REDUCTION_STRATEGY = {
1: "LINEAR",
2: "TREE",
}
def parse_native_bwd_weight_instance(args, instance_id, problem_name):
"""Parse a native CK Tile instance string (GroupedConvolutionBackwardWeightKernel<...>).
Fields (0-indexed after splitting on commas inside <>):
0: NDimSpatial, 1: ConvSpec, 2: InLayout, 3: WeiLayout, 4: DsLayout, 5: OutLayout,
6: VecA, 7: VecB, 8: VecC, 9: NumGroupsToMerge, 10: SplitImage, 11: ExplicitGemm,
12: MPerBlock, 13: NPerBlock, 14: KPerBlock, 15: MWarp, 16: NWarp, 17: KWarp,
18: MWarpTile, 19: NWarpTile, 20: KWarpTile, 21: ADataType, 22: BDataType,
23: PipelineName, 24: Scheduler, 25: DoubleSmemBuffer, 26: NumWaveGroups,
27: AccDataType, 28: EDataType, 29: DsDataType, 30: CDEElementwiseOp,
31: IsStreamK, [32: ReductionStrategy, 33: PersistentDP]
"""
spec = args[1]
tile_size = [int(args[12]), int(args[13]), int(args[14])]
warps = [int(args[15]), int(args[16]), int(args[17])]
warp_tile = [int(args[18]), int(args[19]), int(args[20])]
pipeline_name = args[23]
if pipeline_name not in PIPELINE_NAME_TO_VERSION:
raise RuntimeError(
f"Unknown pipeline name '{pipeline_name}' in native instance {instance_id}"
)
pipeline_version = PIPELINE_NAME_TO_VERSION[pipeline_name]
scheduler = args[24]
double_smem_buffer = int(args[25]) != 0
num_wave_groups = int(args[26])
scalar_per_vector = [int(args[6]), int(args[7]), int(args[8])]
num_groups_to_merge = int(args[9])
split_image = int(args[10]) != 0
explicit_gemm = int(args[11]) != 0
is_streamk = int(args[31]) != 0
streamk_reduction_strategy = None
streamk_persistent = False
is_two_stage = get_dtype(problem_name) != "float" and scalar_per_vector[2] == 1
if is_streamk:
is_two_stage = False
reduction_int = int(args[32])
streamk_reduction_strategy = STREAMK_REDUCTION_STRATEGY.get(
reduction_int, str(reduction_int)
)
streamk_persistent = int(args[33]) != 0
return ConvInstanceTemplateParams(
spec,
tile_size,
warps,
warp_tile,
double_smem_buffer,
num_wave_groups,
is_two_stage,
pipeline_version,
scheduler,
scalar_per_vector,
num_groups_to_merge,
split_image,
explicit_gemm,
instance_id,
streamk_enabled=is_streamk,
streamk_reduction_strategy=streamk_reduction_strategy,
streamk_persistent=streamk_persistent,
)
def parse_native_fwd_instance(args, instance_id, problem_name):
"""Parse a native CK Tile forward conv instance string."""
raise NotImplementedError("Native forward instance parsing is not yet implemented.")
def parse_native_bwd_data_instance(args, instance_id, problem_name):
"""Parse a native CK Tile backward data instance string."""
raise NotImplementedError(
"Native backward data instance parsing is not yet implemented."
)
# Maps kernel type prefix to native parser function.
NATIVE_PARSERS = {
"GroupedConvolutionBackwardWeightKernel": parse_native_bwd_weight_instance,
"GroupedConvolutionForwardKernel": parse_native_fwd_instance,
"GroupedConvolutionBackwardDataKernel": parse_native_bwd_data_instance,
}
def try_parse_native_instance(instance, instance_id, problem_name):
"""Try to parse an instance line as a native CK Tile instance string.
Returns a ConvInstanceTemplateParams if the line matches a native format,
or None if it doesn't match (so the caller can fall through to old CK parsing).
"""
stripped = instance.strip()
for prefix, parser in NATIVE_PARSERS.items():
if stripped.startswith(prefix + "<"):
start = stripped.index("<") + 1
end = stripped.rindex(">")
params_str = stripped[start:end]
args = parse_instance_string(params_str)
return parser(args, instance_id, problem_name)
return None
def parse_fwd_instances(instances, problem_name):
convs = []
for instance_id, instance in enumerate(instances):
if instance.find("#") != -1 or instance.find(";") != -1:
continue
native = try_parse_native_instance(instance, instance_id, problem_name)
if native is not None:
convs.append(native)
continue
start = instance.index("<") + 1
end = instance.rindex(">")
params_str = instance[start:end]
args = parse_instance_string(params_str)
is_v3_instance = instance.find("Xdl_CShuffle_V3") != -1
split_image = instance.find("Large_Tensor") != -1
if is_v3_instance:
spec = args[14]
block_size = int(args[16])
m_per_block = int(args[17])
n_per_block = int(args[18])
k_per_block = int(args[19])
k1 = int(args[20])
m_per_xdl = int(args[22])
n_per_xdl = int(args[23])
m_xdl_per_wave = int(args[24])
n_xdl_per_wave = int(args[25])
a_scalar_per_vector = int(args[30])
b_scalar_per_vector = int(args[37])
c_scalar_per_vector = int(args[43])
scheduler = args[44]
pipeline_version = args[45]
direct_load = args[48] == "true"
num_groups_to_merge = int(args[49])
else:
spec = args[14]
block_size = int(args[17])
m_per_block = int(args[18])
n_per_block = int(args[19])
k_per_block = int(args[20])
k1 = int(args[21])
m_per_xdl = int(args[23])
n_per_xdl = int(args[24])
m_xdl_per_wave = int(args[25])
n_xdl_per_wave = int(args[26])
a_scalar_per_vector = int(args[31])
b_scalar_per_vector = int(args[38])
c_scalar_per_vector = int(args[44])
scheduler = "Intrawave"
pipeline_version = "v1"
direct_load = 0
num_groups_to_merge = 0 if split_image else int(args[48])
double_smem_buffer = pipeline_version == "v4"
num_wave_groups = 1
# Replace pipeline if Direct Load
if direct_load:
if pipeline_version == "v1":
pipeline_version = "ASYNC_V1"
elif pipeline_version == "v4":
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError(
f"{pipeline_version} not supported pipeline for direct load"
)
else:
pipeline_version = pipeline_version.upper()
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
warp_size = 64
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if split_image:
print(
f"Skipping instance {instance_id} with split_image since it's not supported yet."
)
continue
if pipeline_version == "V5":
print(
f"Skipping instance {instance_id} with V5 since it's not supported yet."
)
continue
if pipeline_version == "ASYNC_V4":
print(
f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet."
)
continue
is_two_stage = False
conv = ConvInstanceTemplateParams(
spec,
[m_per_block, n_per_block, k_per_block],
[m_warp, n_warp, k_warp],
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
num_groups_to_merge,
split_image,
False,
instance_id,
)
convs.append(conv)
return convs
def parse_bwd_weight_instances(instances, problem_name):
convs = []
for instance_id, instance in enumerate(instances):
if instance.find("#") != -1 or instance.find(";") != -1:
continue
native = try_parse_native_instance(instance, instance_id, problem_name)
if native is not None:
if native.streamk_enabled and get_dtype(problem_name) == "float" and native.pipeline_version.find("ASYNC") != -1:
print(f"Skipping instance {instance_id} with streamk, async, float since it's not supported yet.")
continue
convs.append(native)
continue
device_op_name = instance.split("<")[0]
start = instance.index("<") + 1
end = instance.rindex(">")
params_str = instance[start:end]
args = parse_instance_string(params_str)
direct_load = False
is_v3_instance = instance.find("Xdl_CShuffleV3") != -1
is_two_stage_instance = instance.find("TwoStage") != -1
is_explicit_gemm = device_op_name.find("Explicit") != -1
if is_explicit_gemm:
gemm_params = device_op_name = (
instance.split("<")[2].split(">")[1].split(",")
)
args = [param.split(":")[1].strip() for param in gemm_params]
spec = "Filter1x1Stride1Pad0"
block_size = int(args[0])
mnk_per_block = args[1].split("x")
m_per_block = int(mnk_per_block[0])
n_per_block = int(mnk_per_block[1])
k_per_block = int(mnk_per_block[2])
wave_tile = args[2].split("x")
m_per_xdl = int(wave_tile[0])
n_per_xdl = int(wave_tile[1])
k1_values = args[3].split("x")
ak1 = int(k1_values[0])
bk1 = int(k1_values[1])
k1 = min(ak1, bk1)
wave_map = args[4].split("x")
m_xdl_per_wave = int(wave_map[0])
n_xdl_per_wave = int(wave_map[1])
vector_read = args[5].split("x")
a_scalar_per_vector = int(vector_read[0])
b_scalar_per_vector = int(vector_read[1])
c_scalar_per_vector_seq = [
int(x)
for x in vector_read[2].strip("Seq").strip("(").strip(")").split(",")
]
if len(set(c_scalar_per_vector_seq)) != 1:
raise RuntimeError(
f"c_scalar_per_vector must be the same across all waves for instance {instance_id} with device op {device_op_name}. Found values: {c_scalar_per_vector_seq}"
)
c_scalar_per_vector = c_scalar_per_vector_seq[0]
num_groups_to_merge = 1
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[6]
blk_gemm_pipeline_version = args[7]
else:
spec = args[11]
block_size = int(args[12])
m_per_block = int(args[13])
n_per_block = int(args[14])
k1 = int(args[16])
m_per_xdl = int(args[17])
n_per_xdl = int(args[18])
m_xdl_per_wave = int(args[19])
n_xdl_per_wave = int(args[20])
a_scalar_per_vector = int(args[25])
b_scalar_per_vector = int(args[32])
c_scalar_per_vector = int(args[38])
if is_v3_instance or is_two_stage_instance:
k_per_block = int(args[15])
else:
k0_per_block = int(args[15])
k_per_block = k0_per_block * k1
if is_v3_instance:
if len(args) != 45:
raise RuntimeError(
f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}"
)
direct_load = int(args[43]) == 1
num_groups_to_merge = int(args[44])
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
elif is_two_stage_instance:
if len(args) != 46:
raise RuntimeError(
f"Wrong number of parameters in the TwoStage instance string: {instance}\n"
+ f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters."
)
num_groups_to_merge = args[41]
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
else:
# Regular V1 XDL CShuffle instance
if len(args) != 43:
raise RuntimeError(
f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n"
+ f"Expected 43 parameters for V1 instance. Found {len(args)} parameters."
)
num_groups_to_merge = 1
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = "Intrawave"
blk_gemm_pipeline_version = "v1"
# Common part to all solvers.
# Sanity check for Block GEMM pipeline parameters
# Scheduler must be either Intrawave or Interwave.
# Version must be from v1 to v5
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(
f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}"
)
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
raise RuntimeError(
f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}"
)
split_image = instance.find("Large") != -1
double_smem_buffer = blk_gemm_pipeline_version == "v4"
num_wave_groups = 1
scheduler = block_gemm_pipeline_scheduler
pipeline_version = blk_gemm_pipeline_version.upper()
# OLd CK pipeline version V5 maps to V6 for CK Tile
if pipeline_version == "V5":
pipeline_version = "V6"
if direct_load:
if pipeline_version == "V1":
pipeline_version = "ASYNC_V1"
elif pipeline_version == "V4":
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError(
f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}"
)
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
warp_size = 64
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if not check_vectors(
a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector
):
print(
f"Skipping instance {instance_id} with irregular load since it's not supported yet."
)
continue
if pipeline_version == "V6":
print(
f"Skipping instance {instance_id} with V6 since it's not supported yet."
)
continue
if is_explicit_gemm:
if dtype != "float" and c_scalar_per_vector % 2 != 0:
is_two_stage_instance = True
conv = ConvInstanceTemplateParams(
spec,
[m_per_block, n_per_block, k_per_block],
[m_warp, n_warp, k_warp],
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
num_groups_to_merge,
split_image,
is_explicit_gemm,
instance_id,
)
convs.append(conv)
return convs
def parse_bwd_data_instances(instances, problem_name):
convs = []
for instance_id, instance in enumerate(instances):
if instance.find("#") != -1 or instance.find(";") != -1:
continue
native = try_parse_native_instance(instance, instance_id, problem_name)
if native is not None:
convs.append(native)
continue
start = instance.index('<') + 1
end = instance.rindex('>')
params_str = instance[start:end]
args = parse_instance_string(params_str)
is_v1_instance = instance.find("Xdl_CShuffle<") != -1
if is_v1_instance:
if len(args) != 51:
raise RuntimeError(f"Wrong number of parameters in the V1 XDL CShuffle instance string: {instance}\n" +
f"Expected 51 parameters for V1 instance. Found {len(args)} parameters.")
else:
raise RuntimeError(f"Only V1 XDL CShuffle instances are supported for backward data. Found instance: {instance}")
spec = args[13]
block_size = int(args[17])
m_per_block = int(args[18])
n_per_block = int(args[19])
k_per_block = int(args[20])
ak1 = int(args[21])
bk1 = int(args[22])
m_per_xdl = int(args[23])
n_per_xdl = int(args[24])
m_xdl_per_wave = int(args[25])
n_xdl_per_wave = int(args[26])
a_scalar_per_vector = int(args[31])
b_scalar_per_vector = int(args[38])
c_scalar_per_vector = int(args[44])
if ak1 != bk1:
raise RuntimeError(f"Not supported instance {instance_id} since ak1 != bk1. ak1: {ak1}, bk1: {bk1} in instance: {instance}")
k1 = min(ak1, bk1)
# TODO: Do we need split image for 3D bwd data convs?
split_image = False
# Default optimization parameters
num_groups_to_merge = 1
is_two_stage_instance = False
is_explicit_gemm = False
num_wave_groups = 1
direct_load = False
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[46]
if block_gemm_pipeline_scheduler == "Default":
block_gemm_pipeline_scheduler = "Intrawave"
blk_gemm_pipeline_version = "v1"
if block_gemm_pipeline_scheduler == "Interwave":
blk_gemm_pipeline_version = "v1"
# Sanity check for Block GEMM pipeline parameters
# Scheduler must be either Intrawave or Interwave.
# Version must be from v1 to v5
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")
double_smem_buffer = blk_gemm_pipeline_version == "v4"
scheduler = block_gemm_pipeline_scheduler
pipeline_version = blk_gemm_pipeline_version.upper()
# Old CK pipeline version V5 maps to V6 for CK Tile
if pipeline_version == "V5":
pipeline_version = "V6"
if direct_load:
if pipeline_version == "V1":
pipeline_version = "ASYNC_V1"
elif pipeline_version == "V4":
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError(
f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}"
)
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
warp_size = 64
k_warp = int(block_size / (warp_size * m_warp * n_warp))
dtype = get_dtype(problem_name)
k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block)
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
continue
if pipeline_version == "V6":
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
continue
if a_scalar_per_vector > (m_per_block * k_per_block) // block_size or b_scalar_per_vector > (n_per_block * k_per_block) // block_size:
print(f"Skipping instance {instance_id} because current scalar per vector exceedes tile size")
continue
conv = ConvInstanceTemplateParams(
spec,
[m_per_block, n_per_block, k_per_block],
[m_warp, n_warp, k_warp],
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
num_groups_to_merge,
split_image,
is_explicit_gemm,
instance_id,
)
convs.append(conv)
return convs
def get_signature_base(config):
"""Extract layout_dtype from config name, stripping variant suffixes.
Config names follow {layout}_{dtype}[_{variant}], e.g. nhwgc_fp16_streamk.
The signature is determined by layout and dtype only.
"""
parts = config.split("_")
return f"{parts[0]}_{parts[1]}"
def generate_instances_fwd(
instances, problem_name, config, filter_pattern, instances_path
):
direction = "forward"
signature_name = f"SIGNATURE_{get_signature_base(config).upper()}_FWD"
instances = parse_fwd_instances(instances, problem_name)
generate_calls_inc(instances, problem_name, direction, filter_pattern)
generate_defs_inc(
instances, problem_name, signature_name, direction, filter_pattern
)
generate_conv_cpp(
instances,
problem_name,
config,
direction,
signature_name,
filter_pattern,
instances_path,
)
def generate_instances_bwd_weight(
instances, problem_name, config, filter_pattern, instances_path
):
direction = "backward_weight"
signature_name = f"SIGNATURE_{get_signature_base(config).upper()}_BWD_WEIGHT"
instances = parse_bwd_weight_instances(instances, problem_name)
generate_calls_inc(instances, problem_name, direction, filter_pattern)
generate_defs_inc(
instances, problem_name, signature_name, direction, filter_pattern
)
generate_conv_cpp(
instances,
problem_name,
config,
direction,
signature_name,
filter_pattern,
instances_path,
)
def generate_instances_bwd_data(
instances, problem_name, config, filter_pattern, instances_path
):
direction = "backward_data"
signature_name = f"SIGNATURE_{get_signature_base(config).upper()}_BWD_DATA"
instances = parse_bwd_data_instances(instances, problem_name)
generate_calls_inc(instances, problem_name, direction, filter_pattern)
generate_defs_inc(
instances, problem_name, signature_name, direction, filter_pattern
)
generate_conv_cpp(
instances,
problem_name,
config,
direction,
signature_name,
filter_pattern,
instances_path,
)
def process_direction(
configs, direction, generate_func, configs_prefix, filter_pattern, instances_path
):
"""Helper function to process a single direction."""
for config in configs:
instances = []
generate_dir = Path(__file__).resolve().parent
config_path = (
f"{generate_dir}/configs/{direction}/{configs_prefix}/{config}.conf"
)
with open(config_path, "r") as file:
instances = file.readlines()
# Determine problem name based on direction
if direction == "forward":
problem_name = f"grouped_convolution_forward_tile_{config}"
elif direction == "backward_weight":
problem_name = f"grouped_convolution_backward_weight_tile_{config}"
elif direction == "backward_data":
problem_name = f"grouped_convolution_backward_data_tile_{config}"
else:
raise RuntimeError(f"Unknown direction: {direction}")
generate_func(instances, problem_name, config, filter_pattern, instances_path)
# ---------------------------------------------------------------------------
# Depthwise forward generation
# ---------------------------------------------------------------------------
DEPTHWISE_CONFIGS = [
{
"name": "ngchw_depthwise_fp32",
"conf": "ngchw_depthwise.conf",
"signature": "SIGNATURE_NGCHW_FP32_FWD",
},
{
"name": "ngchw_depthwise_fp16",
"conf": "ngchw_depthwise.conf",
"signature": "SIGNATURE_NGCHW_FP16_FWD",
},
{
"name": "ngchw_depthwise_bf16",
"conf": "ngchw_depthwise.conf",
"signature": "SIGNATURE_NGCHW_BF16_FWD",
},
]
def parse_depthwise_config(conf_path: Path) -> list:
"""Parse a depthwise config file.
Accepts the ``GroupedConvolutionForwardDepthwise<...>`` format.
Returns a list of 12-element integer lists:
[TileH, TileW, Filter, StrH, StrW, PadH, PadW,
NBatch, SubTileH, SubTileW, InVecSize, OutVecSize]
"""
instances = []
for raw in conf_path.read_text().splitlines():
line = raw.strip()
if not line or line.startswith("#"):
continue
if "<" in line and ">" in line:
start = line.index("<") + 1
end = line.rindex(">")
line = line[start:end]
params = [int(x.strip()) for x in line.split(",")]
if len(params) != 12:
raise ValueError(
f"Expected 12 parameters per depthwise instance, got {len(params)}: {raw!r}"
)
instances.append(params)
return instances
def generate_depthwise_cpp(params: list, instance_name: str, signature: str, cpp_out: Path) -> None:
tile_h, tile_w, filt, str_h, str_w, pad_h, pad_w, nbatch, sub_h, sub_w, in_vec, out_vec = params
parent_dir = Path(__file__).resolve().parent
template_file = parent_dir / "include/grouped_convolution_depthwise_tile.cpp.in"
content = template_file.read_text()
content = content.replace("gen_signature", signature)
content = content.replace("gen_instance_name", instance_name)
content = content.replace("gen_block_size", "64")
content = content.replace("gen_tile_h", str(tile_h))
content = content.replace("gen_tile_w", str(tile_w))
content = content.replace("gen_filter_h", str(filt))
content = content.replace("gen_filter_w", str(filt))
content = content.replace("gen_stride_h", str(str_h))
content = content.replace("gen_stride_w", str(str_w))
content = content.replace("gen_dilation_h", "1")
content = content.replace("gen_dilation_w", "1")
content = content.replace("gen_pad_h", str(pad_h))
content = content.replace("gen_pad_w", str(pad_w))
content = content.replace("gen_nbatch", str(nbatch))
content = content.replace("gen_subtile_h", str(sub_h))
content = content.replace("gen_subtile_w", str(sub_w))
content = content.replace("gen_in_vec", str(in_vec))
content = content.replace("gen_out_vec", str(out_vec))
cpp_out.write_text(content)
def generate_depthwise_defs_inc(instances: list, config_name: str, signature: str, inc_path: Path) -> None:
lines = []
for i in range(len(instances)):
name = f"grouped_convolution_forward_tile_{config_name}_{i}"
lines.append(
f"std::tuple<bool, float, std::string> run_{name}(\n"
f" const ckt::Args<{signature}>& args,\n"
f" const ckt::Inputs<{signature}>& inputs,\n"
f" const ckt::Outputs<{signature}>& outputs,\n"
f" const ck_tile::stream_config& s_conf);"
)
inc_path.write_text("\n".join(lines) + "\n")
def generate_depthwise_calls_inc(instances: list, config_name: str, calls_path: Path) -> None:
lines = []
for i in range(len(instances)):
name = f"grouped_convolution_forward_tile_{config_name}_{i}"
lines.append(f"run_alg(run_{name});")
calls_path.write_text("\n".join(lines) + "\n")
def process_depthwise_forward(configs_prefix: str, instances_path: str) -> None:
"""Generate all depthwise forward instances."""
generate_dir = Path(__file__).resolve().parent
conf_dir = generate_dir / "configs/forward" / configs_prefix
inc_dir = generate_dir / "instances" / "forward"
cpp_base = Path(instances_path) / "forward"
for cfg in DEPTHWISE_CONFIGS:
name = cfg["name"]
conf_path = conf_dir / cfg["conf"]
signature = cfg["signature"]
if not conf_path.exists():
print(f" Skipping {name}: config not found at {conf_path}")
continue
instances = parse_depthwise_config(conf_path)
print(f"Processing {name}: {len(instances)} instances ...")
cpp_dir = cpp_base / name
cpp_dir.mkdir(parents=True, exist_ok=True)
for i, params in enumerate(instances):
instance_name = f"grouped_convolution_forward_tile_{name}_{i}"
generate_depthwise_cpp(params, instance_name, signature,
cpp_dir / f"{instance_name}.cpp")
generate_depthwise_defs_inc(instances, name, signature,
inc_dir / f"grouped_convolution_forward_tile_{name}.inc")
generate_depthwise_calls_inc(instances, name,
inc_dir / f"grouped_convolution_forward_tile_{name}_calls.inc")
print(f" -> {cpp_dir} ({len(instances)} .cpp files)")
if __name__ == "__main__":
fwd_configs = [
"nhwgc_fp32",
"nhwgc_fp16",
"nhwgc_bf16",
"ndhwgc_fp32",
"ndhwgc_fp16",
"ndhwgc_bf16",
]
bwd_weight_configs = [
"nhwgc_fp32",
"nhwgc_fp16",
"nhwgc_bf16",
"ndhwgc_fp32",
"ndhwgc_fp16",
"ndhwgc_bf16",
"nhwgc_fp32_streamk",
"nhwgc_fp16_streamk",
"nhwgc_bf16_streamk",
"ndhwgc_fp32_streamk",
"ndhwgc_fp16_streamk",
"ndhwgc_bf16_streamk",
]
bwd_data_configs = [
"nhwgc_fp32",
"nhwgc_fp16",
"nhwgc_bf16",
"ndhwgc_fp32",
"ndhwgc_fp16",
"ndhwgc_bf16",
]
parser = argparse.ArgumentParser(
description="Generate grouped conv CK Tile instances."
)
parser.add_argument(
"--filter_pattern",
type=str,
default="convolution",
help="Filter pattern for configs.",
)
parser.add_argument(
"--mode",
choices=["compilation", "tests", "profiler"],
type=str,
default="profiler",
help="Generator modes. compilation - empty instance list, tests - limited instance list, profiler - generate all instances",
)
parser.add_argument(
"--direction",
choices=["forward", "backward_weight", "backward_data", "all"],
type=str,
default="all",
help="Convolution direction for which to generate instances.",
)
parser.add_argument(
"--instances_dir",
type=str,
default="../build/experimental/grouped_convolution_tile_instances",
help="Directory store generated instances.",
)
args = parser.parse_args()
# apply empty filter
if args.mode == "compilation":
args.filter_pattern = "empty"
configs_prefix = "profiler"
elif args.mode == "tests":
configs_prefix = "tests"
elif args.mode == "profiler":
configs_prefix = "profiler"
else:
raise RuntimeError("wrong mode")
copy_includes(args.instances_dir)
match args.direction:
case "forward":
process_direction(fwd_configs, args.direction, generate_instances_fwd, configs_prefix, args.filter_pattern, args.instances_dir)
process_depthwise_forward(configs_prefix, args.instances_dir)
case "backward_weight":
process_direction(
bwd_weight_configs,
args.direction,
generate_instances_bwd_weight,
configs_prefix,
args.filter_pattern,
args.instances_dir,
)
case "backward_data":
process_direction(
bwd_data_configs,
args.direction,
generate_instances_bwd_data,
configs_prefix,
args.filter_pattern,
args.instances_dir,
)
case "all":
process_direction(fwd_configs, "forward", generate_instances_fwd, configs_prefix, args.filter_pattern, args.instances_dir)
process_depthwise_forward(configs_prefix, args.instances_dir)
process_direction(bwd_weight_configs, "backward_weight", generate_instances_bwd_weight, configs_prefix, args.filter_pattern, args.instances_dir)
process_direction(bwd_data_configs, "backward_data", generate_instances_bwd_data, configs_prefix, args.filter_pattern, args.instances_dir)