# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT import argparse import shutil from pathlib import Path class ConvInstanceTemplateParams: def __init__( self, specialization, tile_size, warps, warp_tile, double_smem_buffer, num_wave_groups, is_two_stage_instance, pipeline_version, scheduler, scalar_per_vector, num_groups_to_merge, split_image, explicit_gemm, id, ): self.specialization = specialization self.tile_size = tile_size self.warps = warps self.warp_tile = warp_tile self.double_smem_buffer = double_smem_buffer self.num_wave_groups = num_wave_groups self.is_two_stage_instance = is_two_stage_instance self.pipeline_version = pipeline_version self.scheduler = scheduler self.scalar_per_vector = scalar_per_vector self.num_groups_to_merge = num_groups_to_merge self.split_image = split_image self.explicit_gemm = explicit_gemm self.id = id def get_optimizations(self): explicit_gemm = "true" if self.explicit_gemm else "false" split_image = "true" if self.split_image else "false" num_groups_to_merge = str(self.num_groups_to_merge) two_stage_instance = "true" if self.is_two_stage_instance else "false" return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}" def get_specialization(self): namespace = "ckb::TileConvSpecialization::" if self.specialization == "Default" or self.specialization == "OddC": return namespace + "DEFAULT" if self.specialization == "Filter1x1Pad0": return namespace + "FILTER_1X1_PAD0" if self.specialization == "Filter1x1Stride1Pad0": return namespace + "FILTER_1X1_STRIDE1_PAD0" if self.specialization == "Filter3x3": return namespace + "FILTER_3x3" else: raise RuntimeError("not supported specialization") def get_thread_block(self): return f"ckt::TileThreadBlock{{.tile_size = {{.m = {self.tile_size[0]}, .n = {self.tile_size[1]}, .k = {self.tile_size[2]}}}}}" def get_block_gemm_desc(self): double_smem_buffer = "true" if self.double_smem_buffer else "false" scheduler = ( "INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE" ) return f"""ckt::TileBlockGemm{{ .warps = {{.m = {self.warps[0]}, .n = {self.warps[1]}, .k = {self.warps[2]}}}, .warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}}, .double_smem_buffer = {double_smem_buffer}, .num_wave_groups = {self.num_wave_groups}, .pipeline_version = ckb::PipelineVersion::{self.pipeline_version}, .scheduler = ckb::PipelineScheduler::{scheduler}}}""" def get_block_transfer(self): return f"""ckt::TileTransfer{{.a_scalar_per_vector = {self.scalar_per_vector[0]}, .b_scalar_per_vector = {self.scalar_per_vector[1]}, .c_scalar_per_vector = {self.scalar_per_vector[2]}}}""" def get_dtype(problem_name): if problem_name.find("fp32") != -1: return "float" if problem_name.find("fp16") != -1: return "ck_tile::half_t" if problem_name.find("bf16") != -1: return "ck_tile::bf16_t" else: raise RuntimeError("Cannot parse data type from problem name: " + problem_name) def get_k_mfma(dtype, m_per_xdl, n_per_xdl): if m_per_xdl != n_per_xdl: raise RuntimeError("Not supported") if dtype == "float": if m_per_xdl == 32: return 2 else: return 4 else: if m_per_xdl == 32: return 8 else: return 16 def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector): if a_scalar_per_vector != 1 and a_scalar_per_vector % 2 != 0: return False if b_scalar_per_vector != 1 and b_scalar_per_vector % 2 != 0: return False if c_scalar_per_vector != 1 and c_scalar_per_vector % 2 != 0: return False return True def parse_instance_string(instance_string): """Parse instance string, treating Seq(...) as a single parameter.""" params = [] current_param = "" paren_depth = 0 for char in instance_string: if char == '(': paren_depth += 1 current_param += char elif char == ')': paren_depth -= 1 current_param += char elif char == ',' and paren_depth == 0: # Only split on comma if we're not inside parentheses params.append(current_param.strip()) current_param = "" else: current_param += char # Add the last parameter if current_param.strip(): params.append(current_param.strip()) return params def copy_includes(instances_path): inc_dir = Path(__file__).resolve().parent output_dir = Path(instances_path) output_dir.mkdir(parents=True, exist_ok=True) shutil.copy(f"{inc_dir}/include/instance_includes.inc", instances_path) shutil.copy(f"{inc_dir}/include/instance_run.inc", instances_path) shutil.copy(f"{inc_dir}/include/signatures.hpp", instances_path) def generate_calls_inc(instances, problem_name, direction, filter_pattern): generate_dir = Path(__file__).resolve().parent output_dir = Path(f"{generate_dir}/instances/{direction}") output_dir.mkdir(parents=True, exist_ok=True) with open(f"{generate_dir}/instances/{direction}/{problem_name}_calls.inc", "w") as f: if problem_name.find(filter_pattern) == -1: return for instance in instances: instance_name = problem_name + "_" + str(instance.id) f.write(f"run_alg(run_{instance_name});\n") def generate_defs_inc(instances, problem_name, signature, direction, filter_pattern): generate_dir = Path(__file__).resolve().parent with open(f"{generate_dir}/instances/{direction}/{problem_name}.inc", "w") as f: if problem_name.find(filter_pattern) == -1: return for instance in instances: instance_name = problem_name + "_" + str(instance.id) f.write( f"std::tuple run_{instance_name}(\n" f" const ckt::Args<{signature}>& args,\n" f" const ckt::Inputs<{signature}>& inputs,\n" f" const ckt::Outputs<{signature}>& outputs,\n" f" const ck_tile::stream_config& s_conf);\n" ) def generate_conv_cpp( instances, problem_name, config, direction, signature_name, filter_pattern, instances_path): for instance in instances: if problem_name.find(filter_pattern) == -1: break instance_name = problem_name + "_" + str(instance.id) directory_path = Path(f"{instances_path}/{direction}/{config}") directory_path.mkdir(parents=True, exist_ok=True) parent_dir = Path(__file__).resolve().parent template_file = "include/grouped_convolution_tile.cpp.in" with open(f"{parent_dir}/{template_file}", "r",) as f: content = f.read() content = content.replace("gen_signature", signature_name) content = content.replace("gen_instance_name", instance_name) content = content.replace("gen_specialization", instance.get_specialization()) content = content.replace("gen_thread_block", instance.get_thread_block()) content = content.replace("gen_block_gemm_desc", instance.get_block_gemm_desc()) content = content.replace("gen_block_transfer", instance.get_block_transfer()) content = content.replace("gen_optimizations", instance.get_optimizations()) with open(f"{instances_path}/{direction}/{config}/{instance_name}.cpp","w",) as f: f.write(content) def parse_fwd_instances(instances, problem_name): convs = [] for instance_id, instance in enumerate(instances): if instance.find("#") != -1 or instance.find(";") != -1: continue start = instance.index('<') + 1 end = instance.rindex('>') params_str = instance[start:end] args = parse_instance_string(params_str) is_v3_instance = instance.find("Xdl_CShuffle_V3") != -1 split_image = instance.find("Large_Tensor") != -1 if is_v3_instance: spec = args[14] block_size = int(args[16]) m_per_block = int(args[17]) n_per_block = int(args[18]) k_per_block = int(args[19]) k1 = int(args[20]) m_per_xdl = int(args[22]) n_per_xdl = int(args[23]) m_xdl_per_wave = int(args[24]) n_xdl_per_wave = int(args[25]) a_scalar_per_vector = int(args[30]) b_scalar_per_vector = int(args[37]) c_scalar_per_vector = int(args[43]) scheduler = args[44] pipeline_version = args[45] direct_load = args[48] == "true" num_groups_to_merge = int(args[49]) else: spec = args[14] block_size = int(args[17]) m_per_block = int(args[18]) n_per_block = int(args[19]) k_per_block = int(args[20]) k1 = int(args[21]) m_per_xdl = int(args[23]) n_per_xdl = int(args[24]) m_xdl_per_wave = int(args[25]) n_xdl_per_wave = int(args[26]) a_scalar_per_vector = int(args[31]) b_scalar_per_vector = int(args[38]) c_scalar_per_vector = int(args[44]) scheduler = "Intrawave" pipeline_version = "v1" direct_load = 0 num_groups_to_merge = 0 if split_image else int(args[48]) double_smem_buffer = pipeline_version == "v4" num_wave_groups = 1 # Replace pipeline if Direct Load if direct_load: if pipeline_version == "v1": pipeline_version = "ASYNC_V1" elif pipeline_version == "v4": pipeline_version = "ASYNC_V4" else: raise RuntimeError(f"{pipeline_version} not supported pipeline for direct load") else: pipeline_version = pipeline_version.upper() m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave)) n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave)) warp_size = 64 k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) if split_image: print(f"Skipping instance {instance_id} with split_image since it's not supported yet.") continue if pipeline_version == "V5": print(f"Skipping instance {instance_id} with V5 since it's not supported yet.") continue if pipeline_version == "ASYNC_V4": print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.") continue is_two_stage = False conv = ConvInstanceTemplateParams( spec, [m_per_block, n_per_block, k_per_block], [m_warp, n_warp, k_warp], [m_per_xdl, n_per_xdl, k_per_xdl], double_smem_buffer, num_wave_groups, is_two_stage, pipeline_version, scheduler, [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], num_groups_to_merge, split_image, False, instance_id, ) convs.append(conv) return convs def parse_bwd_weight_instances(instances, problem_name): convs = [] for instance_id, instance in enumerate(instances): if instance.find("#") != -1 or instance.find(";") != -1: continue device_op_name = instance.split("<")[0] start = instance.index('<') + 1 end = instance.rindex('>') params_str = instance[start:end] args = parse_instance_string(params_str) is_v3_instance = instance.find("Xdl_CShuffleV3") != -1 is_two_stage_instance = instance.find("TwoStage") != -1 is_explicit_gemm = device_op_name.find("Explicit") != -1 if is_explicit_gemm: gemm_params = device_op_name = instance.split("<")[2].split(">")[1].split(",") args = [param.split(":")[1].strip() for param in gemm_params] spec = "Default" block_size = int(args[0]) mnk_per_block = args[1].split("x") m_per_block = int(mnk_per_block[0]) n_per_block = int(mnk_per_block[1]) k_per_block = int(mnk_per_block[2]) wave_tile = args[2].split("x") m_per_xdl = int(wave_tile[0]) n_per_xdl = int(wave_tile[1]) k1_values = args[3].split("x") ak1 = int(k1_values[0]) bk1 = int(k1_values[1]) k1 = min(ak1, bk1) wave_map = args[4].split("x") m_xdl_per_wave = int(wave_map[0]) n_xdl_per_wave = int(wave_map[1]) vector_read = args[5].split("x") a_scalar_per_vector = int(vector_read[0]) b_scalar_per_vector = int(vector_read[1]) c_scalar_per_vector_seq = [int(x) for x in vector_read[2].strip("Seq").strip("(").strip(")").split(",")] if len(set(c_scalar_per_vector_seq)) != 1: raise RuntimeError(f"c_scalar_per_vector must be the same across all waves for instance {instance_id} with device op {device_op_name}. Found values: {c_scalar_per_vector_seq}") c_scalar_per_vector = c_scalar_per_vector_seq[0] num_groups_to_merge = 1 # Block GEMM pipeline parameters block_gemm_pipeline_scheduler = args[6] blk_gemm_pipeline_version = args[7] else: spec = args[11] block_size = int(args[12]) m_per_block = int(args[13]) n_per_block = int(args[14]) k1 = int(args[16]) m_per_xdl = int(args[17]) n_per_xdl = int(args[18]) m_xdl_per_wave = int(args[19]) n_xdl_per_wave = int(args[20]) a_scalar_per_vector = int(args[25]) b_scalar_per_vector = int(args[32]) c_scalar_per_vector = int(args[38]) if is_v3_instance or is_two_stage_instance: k_per_block = int(args[15]) else: k0_per_block = int(args[15]) k_per_block = k0_per_block * k1 if is_v3_instance: if len(args) != 45: raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}") num_groups_to_merge = int(args[44]) # Block GEMM pipeline parameters block_gemm_pipeline_scheduler = args[39] blk_gemm_pipeline_version = args[40] elif is_two_stage_instance: if len(args) != 46: raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" + f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.") num_groups_to_merge = args[41] # Block GEMM pipeline parameters block_gemm_pipeline_scheduler = args[39] blk_gemm_pipeline_version = args[40] else: # Regular V1 XDL CShuffle instance if len(args) != 43: raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" + f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.") num_groups_to_merge = 1 # Block GEMM pipeline parameters block_gemm_pipeline_scheduler = "Intrawave" blk_gemm_pipeline_version = "v1" # Common part to all solvers. # Sanity check for Block GEMM pipeline parameters # Scheduler must be either Intrawave or Interwave. # Version must be from v1 to v5 if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]: raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}") if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]: raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}") split_image = instance.find("Large") != -1 double_smem_buffer = blk_gemm_pipeline_version == "v4" num_wave_groups = 1 scheduler = block_gemm_pipeline_scheduler pipeline_version = blk_gemm_pipeline_version.upper() # OLd CK pipeline version V5 maps to V6 for CK Tile if pipeline_version == "V5": pipeline_version = "V6" m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave)) n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave)) warp_size = 64 k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False: print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.") continue if pipeline_version == "V6": print(f"Skipping instance {instance_id} with V6 since it's not supported yet.") continue conv = ConvInstanceTemplateParams( spec, [m_per_block, n_per_block, k_per_block], [m_warp, n_warp, k_warp], [m_per_xdl, n_per_xdl, k_per_xdl], double_smem_buffer, num_wave_groups, is_two_stage_instance, pipeline_version, scheduler, [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], num_groups_to_merge, split_image, is_explicit_gemm, instance_id, ) convs.append(conv) return convs def parse_bwd_data_instances(instances, problem_name): convs = [] for instance_id, instance in enumerate(instances): if instance.find("#") != -1 or instance.find(";") != -1: continue start = instance.index('<') + 1 end = instance.rindex('>') params_str = instance[start:end] args = parse_instance_string(params_str) is_v1_instance = instance.find("Xdl_CShuffle<") != -1 if is_v1_instance: if len(args) != 51: raise RuntimeError(f"Wrong number of parameters in the V1 XDL CShuffle instance string: {instance}\n" + f"Expected 51 parameters for V1 instance. Found {len(args)} parameters.") else: raise RuntimeError(f"Only V1 XDL CShuffle instances are supported for backward data. Found instance: {instance}") spec = args[13] block_size = int(args[17]) m_per_block = int(args[18]) n_per_block = int(args[19]) k_per_block = int(args[20]) ak1 = int(args[21]) bk1 = int(args[22]) m_per_xdl = int(args[23]) n_per_xdl = int(args[24]) m_xdl_per_wave = int(args[25]) n_xdl_per_wave = int(args[26]) a_scalar_per_vector = int(args[31]) b_scalar_per_vector = int(args[38]) c_scalar_per_vector = int(args[44]) if ak1 != bk1: raise RuntimeError(f"Not supported instance {instance_id} since ak1 != bk1. ak1: {ak1}, bk1: {bk1} in instance: {instance}") k1 = min(ak1, bk1) # TODO: Do we need split image for 3D bwd data convs? split_image = False # Default optimization parameters num_groups_to_merge = 1 is_two_stage_instance = False is_explicit_gemm = False num_wave_groups = 1 direct_load = False # Block GEMM pipeline parameters block_gemm_pipeline_scheduler = args[46] if block_gemm_pipeline_scheduler == "Default": block_gemm_pipeline_scheduler = "Intrawave" blk_gemm_pipeline_version = "v1" if block_gemm_pipeline_scheduler == "Interwave": blk_gemm_pipeline_version = "v1" # Sanity check for Block GEMM pipeline parameters # Scheduler must be either Intrawave or Interwave. # Version must be from v1 to v5 if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]: raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}") if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]: raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}") double_smem_buffer = blk_gemm_pipeline_version == "v4" scheduler = block_gemm_pipeline_scheduler pipeline_version = blk_gemm_pipeline_version.upper() # Old CK pipeline version V5 maps to V6 for CK Tile if pipeline_version == "V5": pipeline_version = "V6" if direct_load: if pipeline_version == "V1": pipeline_version = "ASYNC_V1" elif pipeline_version == "V4": pipeline_version = "ASYNC_V4" else: raise RuntimeError( f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}" ) m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave)) n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave)) warp_size = 64 k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False: print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.") continue if pipeline_version == "V6": print(f"Skipping instance {instance_id} with V6 since it's not supported yet.") continue # Check vector sizes for A and B tensors - we cannot oversubscribe. num_tile_elements_a = m_per_xdl * k_per_xdl num_tile_elements_b = n_per_xdl * k_per_xdl max_vector_size_a = max(1, num_tile_elements_a // block_size) max_vector_size_b = max(1, num_tile_elements_b // block_size) a_scalar_per_vector = min(a_scalar_per_vector, max_vector_size_a) b_scalar_per_vector = min(b_scalar_per_vector, max_vector_size_b) conv = ConvInstanceTemplateParams( spec, [m_per_block, n_per_block, k_per_block], [m_warp, n_warp, k_warp], [m_per_xdl, n_per_xdl, k_per_xdl], double_smem_buffer, num_wave_groups, is_two_stage_instance, pipeline_version, scheduler, [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], num_groups_to_merge, split_image, is_explicit_gemm, instance_id, ) convs.append(conv) return convs def generate_instances_fwd(instances, problem_name, config, filter_pattern, instances_path): direction = "forward" signature_name = f"SIGNATURE_{config.upper()}_FWD" instances = parse_fwd_instances(instances, problem_name) generate_calls_inc(instances, problem_name, direction, filter_pattern) generate_defs_inc( instances, problem_name, signature_name, direction, filter_pattern ) generate_conv_cpp( instances, problem_name, config, direction, signature_name, filter_pattern, instances_path ) def generate_instances_bwd_weight(instances, problem_name, config, filter_pattern, instances_path): direction = "backward_weight" signature_name = f"SIGNATURE_{config.upper()}_BWD_WEIGHT" instances = parse_bwd_weight_instances(instances, problem_name) generate_calls_inc(instances, problem_name, direction, filter_pattern) generate_defs_inc( instances, problem_name, signature_name, direction, filter_pattern ) generate_conv_cpp( instances, problem_name, config, direction, signature_name, filter_pattern, instances_path ) def generate_instances_bwd_data(instances, problem_name, config, filter_pattern, instances_path): direction = "backward_data" signature_name = f"SIGNATURE_{config.upper()}_BWD_DATA" instances = parse_bwd_data_instances(instances, problem_name) generate_calls_inc(instances, problem_name, direction, filter_pattern) generate_defs_inc( instances, problem_name, signature_name, direction, filter_pattern ) generate_conv_cpp( instances, problem_name, config, direction, signature_name, filter_pattern, instances_path ) def process_direction(configs, direction, generate_func, configs_prefix, filter_pattern, instances_path): """Helper function to process a single direction.""" for config in configs: instances = [] generate_dir = Path(__file__).resolve().parent config_path = f"{generate_dir}/configs/{direction}/{configs_prefix}/{config}.conf" with open(config_path, "r") as file: instances = file.readlines() # Determine problem name based on direction if direction == "forward": problem_name = f"grouped_convolution_forward_tile_{config}" elif direction == "backward_weight": problem_name = f"grouped_convolution_backward_weight_tile_{config}" elif direction == "backward_data": problem_name = f"grouped_convolution_backward_data_tile_{config}" else: raise RuntimeError(f"Unknown direction: {direction}") generate_func(instances, problem_name, config, filter_pattern, instances_path) if __name__ == "__main__": fwd_configs = [ "nhwgc_fp32", "nhwgc_fp16", "nhwgc_bf16", "ndhwgc_fp32", "ndhwgc_fp16", "ndhwgc_bf16", ] # FP32 doesn't work for bwd weigth currently bwd_weight_configs = [ "nhwgc_fp32", "nhwgc_fp16", "nhwgc_bf16", "ndhwgc_fp32", "ndhwgc_fp16", "ndhwgc_bf16", ] bwd_data_configs = [ "nhwgc_fp32", "nhwgc_fp16", "nhwgc_bf16", "ndhwgc_fp32", "ndhwgc_fp16", "ndhwgc_bf16", ] parser = argparse.ArgumentParser( description="Generate grouped conv CK Tile instances." ) parser.add_argument( "--filter_pattern", type=str, default="convolution", help="Filter pattern for configs.", ) parser.add_argument( "--mode", choices=["compilation", "tests", "profiler"], type=str, default="profiler", help="Generator modes. compilation - empty instance list, tests - limited instance list, profiler - generate all instances", ) parser.add_argument( "--direction", choices=["forward", "backward_weight", "backward_data", "all"], type=str, default="all", help="Convolution direction for which to generate instances." ) parser.add_argument( "--instances_dir", type=str, default="../build/experimental/grouped_convolution_tile_instances", help="Directory store generated instances." ) args = parser.parse_args() # apply empty filter if args.mode == "compilation": args.filter_pattern = "empty" configs_prefix = "profiler" elif args.mode == "tests": configs_prefix = "tests" elif args.mode == "profiler": configs_prefix = "profiler" else: raise RuntimeError("wrong mode") copy_includes(args.instances_dir) match args.direction: case "forward": process_direction(fwd_configs, args.direction, generate_instances_fwd, configs_prefix, args.filter_pattern, args.instances_dir) case "backward_weight": process_direction(bwd_weight_configs, args.direction, generate_instances_bwd_weight, configs_prefix, args.filter_pattern, args.instances_dir) case "backward_data": process_direction(bwd_data_configs, args.direction, generate_instances_bwd_data, configs_prefix, args.filter_pattern, args.instances_dir) case "all": process_direction(fwd_configs, "forward", generate_instances_fwd, configs_prefix, args.filter_pattern, args.instances_dir) process_direction(bwd_weight_configs, "backward_weight", generate_instances_bwd_weight, configs_prefix, args.filter_pattern, args.instances_dir) process_direction(bwd_data_configs, "backward_data", generate_instances_bwd_data, configs_prefix, args.filter_pattern, args.instances_dir)