mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK][CK Tile] Grouped Convolution Backward Weight set of fixes (#5387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Grouped Convolution Backward Weight split k fixes for CK tile kernels ## Technical Details - get k batch from kargs to get deduced k batch - multiply zeroing size by data type size - disable v6 (producing incorrect results) ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result Pass ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
613 lines
24 KiB
Python
Executable File
613 lines
24 KiB
Python
Executable File
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
class ConvInstanceTemplateParams:
    """Parameters of one convolution instance plus their textual renderings.

    The constructor stores its arguments unchanged. Each ``get_*`` method
    formats some of them as text of the form ``ckt::Name{.field = value}``
    (or a ``ckb::...`` enumerator name).
    """

    # (specialization name, enumerator suffix) pairs, checked in this order.
    # "OddC" is rendered with the same enumerator as "Default".
    _SPECIALIZATIONS = (
        ("Default", "DEFAULT"),
        ("OddC", "DEFAULT"),
        ("Filter1x1Pad0", "FILTER_1X1_PAD0"),
        ("Filter1x1Stride1Pad0", "FILTER_1X1_STRIDE1_PAD0"),
        ("Filter3x3", "FILTER_3x3"),
    )

    def __init__(
        self,
        specialization,
        tile_size,
        warps,
        warp_tile,
        double_smem_buffer,
        num_wave_groups,
        is_two_stage_instance,
        pipeline_version,
        scheduler,
        scalar_per_vector,
        num_groups_to_merge,
        split_image,
        explicit_gemm,
        id,
    ):
        """Store every argument as an attribute of the same name.

        ``tile_size``, ``warps``, ``warp_tile`` and ``scalar_per_vector`` are
        read at indices 0, 1 and 2 by the ``get_*`` methods.
        """
        self.specialization = specialization
        # Indexed as [m, n, k] by get_thread_block / get_block_gemm_desc.
        self.tile_size = tile_size
        self.warps = warps
        self.warp_tile = warp_tile
        self.double_smem_buffer = double_smem_buffer
        self.num_wave_groups = num_wave_groups
        self.is_two_stage_instance = is_two_stage_instance
        self.pipeline_version = pipeline_version
        self.scheduler = scheduler
        # Indexed as [a, b, c] by get_block_transfer.
        self.scalar_per_vector = scalar_per_vector
        self.num_groups_to_merge = num_groups_to_merge
        self.split_image = split_image
        self.explicit_gemm = explicit_gemm
        self.id = id

    @staticmethod
    def _as_cpp_bool(flag):
        """Return ``"true"`` for a truthy flag and ``"false"`` otherwise."""
        if flag:
            return "true"
        return "false"

    def get_optimizations(self):
        """Return the ``ckt::TileOptimizations{...}`` text for this instance."""
        fields = ", ".join(
            (
                f".num_groups_to_merge = {self.num_groups_to_merge!s}",
                f".split_image = {self._as_cpp_bool(self.split_image)}",
                f".explicit_gemm = {self._as_cpp_bool(self.explicit_gemm)}",
                f".two_stage = {self._as_cpp_bool(self.is_two_stage_instance)}",
            )
        )
        return "ckt::TileOptimizations{" + fields + "}"

    def get_specialization(self):
        """Return the ``ckb::TileConvSpecialization::`` enumerator name.

        Raises:
            RuntimeError: If ``self.specialization`` is not a known name.
        """
        for known_name, enumerator in self._SPECIALIZATIONS:
            if self.specialization == known_name:
                return "ckb::TileConvSpecialization::" + enumerator
        raise RuntimeError("not supported specialization")

    def get_thread_block(self):
        """Return the ``ckt::TileThreadBlock{...}`` text built from ``tile_size``."""
        tile = self.tile_size
        size_text = f"{{.m = {tile[0]}, .n = {tile[1]}, .k = {tile[2]}}}"
        return "ckt::TileThreadBlock{.tile_size = " + size_text + "}"

    def get_block_gemm_desc(self):
        """Return the multi-line ``ckt::TileBlockGemm{...}`` text.

        The scheduler is rendered as ``INTRAWAVE`` when ``self.scheduler``
        contains ``"Intrawave"`` and as ``INTERWAVE`` in every other case.
        """
        double_smem_buffer = self._as_cpp_bool(self.double_smem_buffer)
        if "Intrawave" in self.scheduler:
            scheduler = "INTRAWAVE"
        else:
            scheduler = "INTERWAVE"
        warps = self.warps
        warp_tile = self.warp_tile
        return "\n".join(
            (
                "ckt::TileBlockGemm{",
                f".warps = {{.m = {warps[0]}, .n = {warps[1]}, .k = {warps[2]}}},",
                f".warp_tile = {{.m = {warp_tile[0]}, .n = {warp_tile[1]}, .k = {warp_tile[2]}}},",
                f".double_smem_buffer = {double_smem_buffer},",
                f".num_wave_groups = {self.num_wave_groups},",
                f".pipeline_version = ckb::PipelineVersion::{self.pipeline_version},",
                f".scheduler = ckb::PipelineScheduler::{scheduler}}}",
            )
        )

    def get_block_transfer(self):
        """Return the two-line ``ckt::TileTransfer{...}`` text."""
        vec = self.scalar_per_vector
        first_line = f"ckt::TileTransfer{{.a_scalar_per_vector = {vec[0]},"
        second_line = f".b_scalar_per_vector = {vec[1]}, .c_scalar_per_vector = {vec[2]}}}"
        return first_line + "\n" + second_line
|
|
|
|
|
|
def get_dtype(problem_name):
    """Return the element type name matching the tag found in ``problem_name``.

    Tags are tried in the order ``fp32``, ``fp16``, ``bf16``; the first one
    contained in ``problem_name`` decides the result.

    Raises:
        RuntimeError: If none of the tags occurs in ``problem_name``.
    """
    tag_to_type = (
        ("fp32", "float"),
        ("fp16", "ck_tile::half_t"),
        ("bf16", "ck_tile::bf16_t"),
    )
    for tag, type_name in tag_to_type:
        if tag in problem_name:
            return type_name
    raise RuntimeError("Cannot parse data type from problem name: " + problem_name)
|
|
|
|
def get_k_mfma(dtype, m_per_xdl, n_per_xdl):
    """Return the K value associated with a data type and a square M/N size.

    ``"float"`` yields 2 for a size of 32 and 4 otherwise; every other
    ``dtype`` yields 8 for a size of 32 and 16 otherwise.

    Raises:
        RuntimeError: If ``m_per_xdl`` and ``n_per_xdl`` differ.
    """
    if m_per_xdl != n_per_xdl:
        raise RuntimeError("Not supported")
    is_size_32 = m_per_xdl == 32
    if dtype == "float":
        return 2 if is_size_32 else 4
    return 8 if is_size_32 else 16
|
|
|
|
def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
    """Return True when each of the three widths is either 1 or even."""
    widths = (a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector)
    return all(width == 1 or width % 2 == 0 for width in widths)
|
|
|
|
def parse_instance_string(instance_string):
    """Split a parameter list on top-level commas.

    Commas inside parentheses (for example in ``Seq(1, 2)``) do not split, so
    such a group stays one parameter. Every parameter is stripped of
    surrounding whitespace. Text after the last comma is only kept when it is
    non-empty after stripping.
    """
    pieces = []
    buffer = []
    depth = 0  # current parenthesis nesting level

    for ch in instance_string:
        if ch == "," and depth == 0:
            # Top-level comma: close the parameter collected so far.
            pieces.append("".join(buffer).strip())
            buffer = []
            continue
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        buffer.append(ch)

    tail = "".join(buffer).strip()
    if tail:
        pieces.append(tail)
    return pieces
|
|
|
|
|
|
def generate_calls_inc(instances, problem_name, direction, filter_pattern):
    """Write ``instances/<direction>/<problem_name>_calls.inc`` next to this file.

    The directory is created if needed and the file is always (re)created.
    It stays empty when ``filter_pattern`` does not occur in
    ``problem_name``; otherwise it receives one ``run_alg(run_<name>);`` line
    per instance, where ``<name>`` is ``<problem_name>_<instance.id>``.
    """
    script_dir = Path(__file__).resolve().parent
    target_dir = Path(f"{script_dir}/instances/{direction}")
    target_dir.mkdir(parents=True, exist_ok=True)

    calls_path = f"{script_dir}/instances/{direction}/{problem_name}_calls.inc"
    with open(calls_path, "w") as out:
        # The filter is evaluated after opening so a non-matching problem
        # still leaves an empty file behind.
        if problem_name.find(filter_pattern) != -1:
            for entry in instances:
                out.write(f"run_alg(run_{problem_name}_{entry.id!s});\n")
|
|
|
|
|
|
def generate_defs_inc(instances, problem_name, signature, direction, filter_pattern):
    """Write ``instances/<direction>/<problem_name>.inc`` next to this file.

    The file is always (re)created. It stays empty when ``filter_pattern``
    does not occur in ``problem_name``; otherwise it receives one
    ``run_<problem_name>_<instance.id>`` function declaration per instance,
    parameterised with ``signature``.

    The output directory is created when missing, so this function no longer
    depends on ``generate_calls_inc`` having run first.
    """
    generate_dir = Path(__file__).resolve().parent
    output_dir = Path(f"{generate_dir}/instances/{direction}")
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(f"{generate_dir}/instances/{direction}/{problem_name}.inc", "w") as f:
        # Checked after opening so a non-matching problem still gets an
        # empty file.
        if problem_name.find(filter_pattern) == -1:
            return
        for instance in instances:
            instance_name = problem_name + "_" + str(instance.id)
            f.write(
                f"std::tuple<bool, float, std::string> run_{instance_name}(\n"
                f" const ckt::Args<{signature}>& args,\n"
                f" const ckt::Inputs<{signature}>& inputs,\n"
                f" const ckt::Outputs<{signature}>& outputs,\n"
                f" const ck_tile::stream_config& s_conf);\n"
            )
|
|
|
|
|
|
def generate_conv_cpp(
    instances, problem_name, config, direction, signature_name, filter_pattern
):
    """Render one ``.cpp`` file per instance from the shared template.

    The template ``instances/grouped_convolution_tile.cpp.in`` (next to this
    file) is read once. For every instance its ``gen_*`` placeholders are
    replaced and the result is written to
    ``instances/<direction>/<config>/<problem_name>_<instance.id>.cpp``.

    Nothing is read or written when ``instances`` is empty or when
    ``filter_pattern`` does not occur in ``problem_name``.
    """
    # Neither condition depends on the individual instance, so decide once.
    if not instances or problem_name.find(filter_pattern) == -1:
        return

    generate_dir = Path(__file__).resolve().parent
    directory_path = Path(f"{generate_dir}/instances/{direction}/{config}")
    directory_path.mkdir(parents=True, exist_ok=True)

    template_file = "grouped_convolution_tile.cpp.in"
    with open(f"{generate_dir}/instances/{template_file}", "r") as f:
        template = f.read()

    for instance in instances:
        instance_name = problem_name + "_" + str(instance.id)

        # Applied in this order, exactly one placeholder at a time.
        replacements = (
            ("gen_signature", signature_name),
            ("gen_instance_name", instance_name),
            ("gen_specialization", instance.get_specialization()),
            ("gen_thread_block", instance.get_thread_block()),
            ("gen_block_gemm_desc", instance.get_block_gemm_desc()),
            ("gen_block_transfer", instance.get_block_transfer()),
            ("gen_optimizations", instance.get_optimizations()),
        )
        content = template
        for placeholder, value in replacements:
            content = content.replace(placeholder, value)

        with open(
            f"{generate_dir}/instances/{direction}/{config}/{instance_name}.cpp", "w"
        ) as f:
            f.write(content)
|
|
|
|
|
|
def parse_fwd_instances(instances, problem_name):
    """Turn forward-convolution instance strings into template parameters.

    Each accepted entry must look like ``Name<p0, p1, ...>``; the parameters
    between the first ``<`` and the last ``>`` are split with
    ``parse_instance_string`` and read by position.

    Args:
        instances: Iterable of instance strings (``process_direction`` passes
            the lines of a ``.conf`` file). Blank entries and entries that
            contain ``#`` or ``;`` are ignored.
        problem_name: Problem name; its data-type tag is resolved with
            ``get_dtype``.

    Returns:
        A list of ``ConvInstanceTemplateParams``. ``id`` is the position of
        the entry in ``instances``; ignored and skipped entries still consume
        a position.

    Raises:
        ValueError: If an entry has no ``<`` / ``>`` or a numeric field is
            not an integer.
        IndexError: If an entry has fewer parameters than are read.
        RuntimeError: If direct load is requested with a pipeline other than
            v1 or v4, or if ``get_dtype`` / ``get_k_mfma`` reject the input.
    """
    convs = []
    for instance_id, instance in enumerate(instances):
        # A blank line (e.g. at the end of a config file) is not an instance;
        # without this guard index('<') below would raise ValueError.
        if not instance.strip():
            continue
        if "#" in instance or ";" in instance:
            continue

        params_str = instance[instance.index("<") + 1 : instance.rindex(">")]
        args = parse_instance_string(params_str)

        is_v3_instance = "Xdl_CShuffle_V3" in instance
        split_image = "Large_Tensor" in instance

        # From block_size onwards, non-V3 strings carry every positional
        # field one slot later than V3 strings.
        shift = 0 if is_v3_instance else 1
        spec = args[14]
        block_size = int(args[16 + shift])
        m_per_block = int(args[17 + shift])
        n_per_block = int(args[18 + shift])
        k_per_block = int(args[19 + shift])
        k1 = int(args[20 + shift])
        m_per_xdl = int(args[22 + shift])
        n_per_xdl = int(args[23 + shift])
        m_xdl_per_wave = int(args[24 + shift])
        n_xdl_per_wave = int(args[25 + shift])
        a_scalar_per_vector = int(args[30 + shift])
        b_scalar_per_vector = int(args[37 + shift])
        c_scalar_per_vector = int(args[43 + shift])

        if is_v3_instance:
            scheduler = args[44]
            pipeline_version = args[45]
            direct_load = args[48] == "true"
            num_groups_to_merge = int(args[49])
        else:
            # Non-V3 strings do not spell these out; fixed defaults are used.
            scheduler = "Intrawave"
            pipeline_version = "v1"
            direct_load = False
            num_groups_to_merge = 0 if split_image else int(args[48])

        # Decided on the raw version string, before it is rewritten below.
        double_smem_buffer = pipeline_version == "v4"
        num_wave_groups = 1

        # Replace pipeline if Direct Load
        if direct_load:
            if pipeline_version == "v1":
                pipeline_version = "ASYNC_V1"
            elif pipeline_version == "v4":
                pipeline_version = "ASYNC_V4"
            else:
                raise RuntimeError(f"{pipeline_version} not supported pipeline for direct load")
        else:
            pipeline_version = pipeline_version.upper()

        m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
        n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
        warp_size = 64  # hard-coded here
        k_warp = int(block_size / (warp_size * m_warp * n_warp))
        dtype = get_dtype(problem_name)
        k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))

        # Unsupported variants are reported and dropped only after all fields
        # have been parsed, so malformed entries still raise.
        if split_image:
            print(f"Skipping instance {instance_id} with split_image since it's not supported yet.")
            continue
        if pipeline_version == "V5":
            print(f"Skipping instance {instance_id} with V5 since it's not supported yet.")
            continue
        if pipeline_version == "ASYNC_V4":
            print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
            continue

        is_two_stage = False

        conv = ConvInstanceTemplateParams(
            spec,
            [m_per_block, n_per_block, k_per_block],
            [m_warp, n_warp, k_warp],
            [m_per_xdl, n_per_xdl, k_per_xdl],
            double_smem_buffer,
            num_wave_groups,
            is_two_stage,
            pipeline_version,
            scheduler,
            [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
            num_groups_to_merge,
            split_image,
            False,  # explicit_gemm
            instance_id,
        )
        convs.append(conv)
    return convs
|
|
|
|
def parse_bwd_weight_instances(instances, problem_name):
    """Turn backward-weight instance strings into template parameters.

    Three positional layouts are handled (V3, TwoStage and regular V1
    instances), plus "Explicit" GEMM instances whose parameters are read
    from ``key: value`` pairs instead of positions.

    Args:
        instances: Iterable of instance strings (``process_direction`` passes
            the lines of a ``.conf`` file). Blank entries and entries that
            contain ``#`` or ``;`` are ignored.
        problem_name: Problem name; its data-type tag is resolved with
            ``get_dtype``.

    Returns:
        A list of ``ConvInstanceTemplateParams``. ``id`` is the position of
        the entry in ``instances``; ignored and skipped entries still consume
        a position.

    Raises:
        ValueError: If an entry has no ``<`` / ``>`` or a numeric field is
            not an integer.
        IndexError: If an entry has fewer parameters than are read.
        RuntimeError: On a wrong parameter count, an invalid scheduler or
            pipeline version, differing C vector widths in an explicit-GEMM
            entry, or if ``get_dtype`` / ``get_k_mfma`` reject the input.
    """
    convs = []

    for instance_id, instance in enumerate(instances):
        # A blank line (e.g. at the end of a config file) is not an instance;
        # without this guard index('<') below would raise ValueError.
        if not instance.strip():
            continue
        if "#" in instance or ";" in instance:
            continue

        device_op_name = instance.split("<")[0]
        start = instance.index("<") + 1
        end = instance.rindex(">")
        params_str = instance[start:end]
        args = parse_instance_string(params_str)

        is_v3_instance = "Xdl_CShuffleV3" in instance
        is_two_stage_instance = "TwoStage" in instance
        is_explicit_gemm = "Explicit" in device_op_name

        if is_explicit_gemm:
            # Parameters come as "key: value" pairs after the first '>' of
            # the third '<'-delimited segment. device_op_name is deliberately
            # left untouched so the error message below names the device op.
            gemm_params = instance.split("<")[2].split(">")[1].split(",")
            args = [param.split(":")[1].strip() for param in gemm_params]

            spec = "Default"
            block_size = int(args[0])

            mnk_per_block = args[1].split("x")
            m_per_block = int(mnk_per_block[0])
            n_per_block = int(mnk_per_block[1])
            k_per_block = int(mnk_per_block[2])

            wave_tile = args[2].split("x")
            m_per_xdl = int(wave_tile[0])
            n_per_xdl = int(wave_tile[1])

            k1_values = args[3].split("x")
            ak1 = int(k1_values[0])
            bk1 = int(k1_values[1])
            k1 = min(ak1, bk1)

            wave_map = args[4].split("x")
            m_xdl_per_wave = int(wave_map[0])
            n_xdl_per_wave = int(wave_map[1])

            vector_read = args[5].split("x")
            a_scalar_per_vector = int(vector_read[0])
            b_scalar_per_vector = int(vector_read[1])
            c_scalar_per_vector_seq = [
                int(x)
                for x in vector_read[2].strip("Seq").strip("(").strip(")").split(",")
            ]

            if len(set(c_scalar_per_vector_seq)) != 1:
                raise RuntimeError(f"c_scalar_per_vector must be the same across all waves for instance {instance_id} with device op {device_op_name}. Found values: {c_scalar_per_vector_seq}")

            c_scalar_per_vector = c_scalar_per_vector_seq[0]

            num_groups_to_merge = 1

            # Block GEMM pipeline parameters
            block_gemm_pipeline_scheduler = args[6]
            blk_gemm_pipeline_version = args[7]
        else:
            spec = args[11]
            block_size = int(args[12])
            m_per_block = int(args[13])
            n_per_block = int(args[14])
            k1 = int(args[16])
            m_per_xdl = int(args[17])
            n_per_xdl = int(args[18])
            m_xdl_per_wave = int(args[19])
            n_xdl_per_wave = int(args[20])
            a_scalar_per_vector = int(args[25])
            b_scalar_per_vector = int(args[32])
            c_scalar_per_vector = int(args[38])

            if is_v3_instance or is_two_stage_instance:
                k_per_block = int(args[15])
            else:
                # Regular instances store K0 here; K per block is K0 * K1.
                k0_per_block = int(args[15])
                k_per_block = k0_per_block * k1

            if is_v3_instance:
                if len(args) != 45:
                    raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}")

                num_groups_to_merge = int(args[44])

                # Block GEMM pipeline parameters
                block_gemm_pipeline_scheduler = args[39]
                blk_gemm_pipeline_version = args[40]
            elif is_two_stage_instance:
                if len(args) != 46:
                    raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" +
                                       f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.")

                # Converted to int like in the other branches, so a
                # non-numeric field fails here rather than in generated code.
                num_groups_to_merge = int(args[41])

                # Block GEMM pipeline parameters
                block_gemm_pipeline_scheduler = args[39]
                blk_gemm_pipeline_version = args[40]
            else:
                # Regular V1 XDL CShuffle instance
                if len(args) != 43:
                    raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" +
                                       f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.")

                num_groups_to_merge = 1

                # Block GEMM pipeline parameters
                block_gemm_pipeline_scheduler = "Intrawave"
                blk_gemm_pipeline_version = "v1"

        # Common part to all solvers.

        # Sanity check for Block GEMM pipeline parameters:
        # the scheduler must be either Intrawave or Interwave and the
        # version must be one of v1 .. v5.
        if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
            raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
        if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
            raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")

        split_image = "Large" in instance
        double_smem_buffer = blk_gemm_pipeline_version == "v4"
        num_wave_groups = 1
        scheduler = block_gemm_pipeline_scheduler
        pipeline_version = blk_gemm_pipeline_version.upper()

        # Old CK pipeline version V5 maps to V6 for CK Tile
        if pipeline_version == "V5":
            pipeline_version = "V6"

        m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
        n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
        warp_size = 64  # hard-coded here
        k_warp = int(block_size / (warp_size * m_warp * n_warp))
        dtype = get_dtype(problem_name)

        k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))

        # Unsupported variants are reported and dropped only after all fields
        # have been parsed, so malformed entries still raise.
        if not check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
            print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
            continue
        if pipeline_version == "V6":
            print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
            continue

        conv = ConvInstanceTemplateParams(
            spec,
            [m_per_block, n_per_block, k_per_block],
            [m_warp, n_warp, k_warp],
            [m_per_xdl, n_per_xdl, k_per_xdl],
            double_smem_buffer,
            num_wave_groups,
            is_two_stage_instance,
            pipeline_version,
            scheduler,
            [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
            num_groups_to_merge,
            split_image,
            is_explicit_gemm,
            instance_id,
        )
        convs.append(conv)

    return convs
|
|
|
|
def parse_bwd_data_instances(instances, problem_name):
    """Placeholder parser for backward-data instances.

    Both arguments are ignored: a notice is printed and an empty list is
    returned, so no backward-data instance is generated.
    """
    # TODO: Implement parsing logic for backward data instances.
    print("Parsing backward data instances is not supported yet, skipping all instances.")
    return []
|
|
|
|
def generate_instances_fwd(instances, problem_name, config, filter_pattern):
    """Parse forward instance strings and emit all generated files for them.

    The signature name is ``SIGNATURE_<CONFIG>_FWD`` with ``config``
    upper-cased; output goes to the ``forward`` direction directory.
    """
    direction = "forward"
    signature_name = "SIGNATURE_" + config.upper() + "_FWD"
    parsed = parse_fwd_instances(instances, problem_name)

    # Calls list first (it creates the output directory), then declarations,
    # then one .cpp per instance.
    generate_calls_inc(parsed, problem_name, direction, filter_pattern)
    generate_defs_inc(parsed, problem_name, signature_name, direction, filter_pattern)
    generate_conv_cpp(
        parsed, problem_name, config, direction, signature_name, filter_pattern
    )
|
|
|
|
def generate_instances_bwd_weight(instances, problem_name, config, filter_pattern):
    """Parse backward-weight instance strings and emit all generated files.

    The signature name is ``SIGNATURE_<CONFIG>_BWD_WEIGHT`` with ``config``
    upper-cased; output goes to the ``backward_weight`` direction directory.
    """
    direction = "backward_weight"
    signature_name = "SIGNATURE_" + config.upper() + "_BWD_WEIGHT"
    parsed = parse_bwd_weight_instances(instances, problem_name)

    # Calls list first (it creates the output directory), then declarations,
    # then one .cpp per instance.
    generate_calls_inc(parsed, problem_name, direction, filter_pattern)
    generate_defs_inc(parsed, problem_name, signature_name, direction, filter_pattern)
    generate_conv_cpp(
        parsed, problem_name, config, direction, signature_name, filter_pattern
    )
|
|
|
|
def generate_instances_bwd_data(instances, problem_name, config, filter_pattern):
    """Parse backward-data instance strings and emit all generated files.

    The signature name is ``SIGNATURE_<CONFIG>_BWD_DATA`` with ``config``
    upper-cased; output goes to the ``backward_data`` direction directory.
    """
    direction = "backward_data"
    signature_name = "SIGNATURE_" + config.upper() + "_BWD_DATA"
    parsed = parse_bwd_data_instances(instances, problem_name)

    # Calls list first (it creates the output directory), then declarations,
    # then one .cpp per instance.
    generate_calls_inc(parsed, problem_name, direction, filter_pattern)
    generate_defs_inc(parsed, problem_name, signature_name, direction, filter_pattern)
    generate_conv_cpp(
        parsed, problem_name, config, direction, signature_name, filter_pattern
    )
|
|
|
|
def process_direction(configs, direction, generate_func, configs_prefix, filter_pattern):
    """Read each config file of one direction and hand its lines to a generator.

    For every ``config`` the file
    ``configs/<direction>/<configs_prefix>/<config>.conf`` (next to this
    file) is read, and ``generate_func(lines, problem_name, config,
    filter_pattern)`` is called with
    ``problem_name = grouped_convolution_<direction>_tile_<config>``.

    Raises:
        RuntimeError: If ``direction`` is not ``forward``,
            ``backward_weight`` or ``backward_data``. The config file is
            opened before this check.
    """
    problem_prefixes = {
        "forward": "grouped_convolution_forward_tile_",
        "backward_weight": "grouped_convolution_backward_weight_tile_",
        "backward_data": "grouped_convolution_backward_data_tile_",
    }
    script_dir = Path(__file__).resolve().parent

    for config in configs:
        conf_file = f"{script_dir}/configs/{direction}/{configs_prefix}/{config}.conf"
        with open(conf_file, "r") as handle:
            lines = list(handle)

        # Determine problem name based on direction
        prefix = problem_prefixes.get(direction)
        if prefix is None:
            raise RuntimeError(f"Unknown direction: {direction}")

        generate_func(lines, f"{prefix}{config}", config, filter_pattern)
|
|
|
|
if __name__ == "__main__":
    # Every direction currently uses the same layout x data-type matrix.
    layout_dtype_configs = [
        f"{layout}_{dtype}"
        for layout in ("nhwgc", "ndhwgc")
        for dtype in ("fp32", "fp16", "bf16")
    ]
    fwd_configs = list(layout_dtype_configs)
    # NOTE(review): an earlier comment here said FP32 does not work for
    # backward weight, yet the FP32 configs are listed - confirm which is
    # intended.
    bwd_weight_configs = list(layout_dtype_configs)
    bwd_data_configs = list(layout_dtype_configs)

    parser = argparse.ArgumentParser(
        description="Generate grouped conv CK Tile instances."
    )
    parser.add_argument(
        "--filter_pattern",
        type=str,
        default="convolution",
        help="Filter pattern for configs.",
    )
    parser.add_argument(
        "--mode",
        choices=["compilation", "tests", "profiler"],
        type=str,
        default="profiler",
        help="Generator modes. compilation - empty instance list, tests - limited instance list, profiler - generate all instances",
    )
    parser.add_argument(
        "--direction",
        choices=["forward", "backward_weight", "backward_data", "all"],
        type=str,
        default="all",
        help="Convolution direction for which to generate instances.",
    )
    args = parser.parse_args()

    # "compilation" reuses the profiler configs but replaces the filter with
    # "empty", which the generated problem names do not contain, so only
    # empty include files are produced.
    if args.mode == "compilation":
        args.filter_pattern = "empty"
    mode_to_prefix = {
        "compilation": "profiler",
        "tests": "tests",
        "profiler": "profiler",
    }
    if args.mode not in mode_to_prefix:
        raise RuntimeError("wrong mode")
    configs_prefix = mode_to_prefix[args.mode]

    # Insertion order is the processing order used for "all".
    plans = {
        "forward": (fwd_configs, generate_instances_fwd),
        "backward_weight": (bwd_weight_configs, generate_instances_bwd_weight),
        "backward_data": (bwd_data_configs, generate_instances_bwd_data),
    }
    if args.direction == "all":
        selected = list(plans)
    else:
        selected = [args.direction]

    for chosen_direction in selected:
        chosen_configs, chosen_generator = plans[chosen_direction]
        process_direction(
            chosen_configs,
            chosen_direction,
            chosen_generator,
            configs_prefix,
            args.filter_pattern,
        )
|
|
|