# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generate grouped convolution CK Tile instance sources from config files.

For every config, the script reads a list of instance strings, parses each
one into a ``ConvInstanceTemplateParams`` and writes, below the directory
containing this script:

* ``instances/<direction>/<problem_name>_calls.inc`` - one call per instance,
* ``instances/<direction>/<problem_name>.inc`` - one declaration per instance,
* ``instances/<direction>/<config>/<instance_name>.cpp`` - the instantiated
  ``grouped_convolution_tile.cpp.in`` template.
"""

import argparse
from pathlib import Path

# Warp size hard-coded by the parsers when deriving the number of warps along k.
_WARP_SIZE = 64


class ConvInstanceTemplateParams:
    """Parameters of one convolution instance plus their C++ renderings.

    The ``get_*`` methods return C++ initializer snippets that
    ``generate_conv_cpp`` substitutes for the ``gen_*`` placeholders of the
    ``.cpp.in`` template.
    """

    def __init__(
        self,
        specialization,
        tile_size,
        warps,
        warp_tile,
        double_smem_buffer,
        num_wave_groups,
        pipeline_version,
        scheduler,
        scalar_per_vector,
        num_groups_to_merge,
        split_image,
        explicit_gemm,
        id,
    ):
        """Store the parameters unchanged.

        ``tile_size``, ``warps``, ``warp_tile`` and ``scalar_per_vector`` are
        indexed with 0, 1 and 2 by the ``get_*`` methods (m/n/k and a/b/c
        respectively).
        """
        self.specialization = specialization
        self.tile_size = tile_size
        self.warps = warps
        self.warp_tile = warp_tile
        self.double_smem_buffer = double_smem_buffer
        self.num_wave_groups = num_wave_groups
        self.pipeline_version = pipeline_version
        self.scheduler = scheduler
        self.scalar_per_vector = scalar_per_vector
        self.num_groups_to_merge = num_groups_to_merge
        self.split_image = split_image
        self.explicit_gemm = explicit_gemm
        self.id = id

    def get_optimizations(self):
        """Return the ``ckt::TileOptimizations{...}`` initializer."""
        explicit_gemm = "true" if self.explicit_gemm else "false"
        split_image = "true" if self.split_image else "false"
        num_groups_to_merge = str(self.num_groups_to_merge)
        return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}"

    def get_specialization(self):
        """Return the ``ckb::TileConvSpecialization`` enumerator name.

        Raises:
            RuntimeError: if the stored specialization is not one of the
                names handled below.
        """
        namespace = "ckb::TileConvSpecialization::"
        # "OddC" is rendered exactly like "Default".
        if self.specialization in ("Default", "OddC"):
            return namespace + "DEFAULT"
        if self.specialization == "Filter1x1Pad0":
            return namespace + "FILTER_1X1_PAD0"
        if self.specialization == "Filter1x1Stride1Pad0":
            return namespace + "FILTER_1X1_STRIDE1_PAD0"
        if self.specialization == "Filter3x3":
            return namespace + "FILTER_3x3"
        raise RuntimeError("not supported specialization")

    def get_thread_block(self):
        """Return the ``ckt::TileThreadBlock{...}`` initializer."""
        return f"ckt::TileThreadBlock{{.tile_size = {{.m = {self.tile_size[0]}, .n = {self.tile_size[1]}, .k = {self.tile_size[2]}}}}}"

    def get_block_gemm_desc(self):
        """Return the ``ckt::TileBlockGemm{...}`` initializer."""
        double_smem_buffer = "true" if self.double_smem_buffer else "false"
        # Any scheduler string that does not mention "Intrawave" is rendered
        # as INTERWAVE.
        scheduler = "INTRAWAVE" if "Intrawave" in self.scheduler else "INTERWAVE"
        return f"""ckt::TileBlockGemm{{
            .warps = {{.m = {self.warps[0]}, .n = {self.warps[1]}, .k = {self.warps[2]}}},
            .warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
            .double_smem_buffer = {double_smem_buffer},
            .num_wave_groups = {self.num_wave_groups},
            .pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
            .scheduler = ckb::PipelineScheduler::{scheduler}}}"""

    def get_block_transfer(self):
        """Return the ``ckt::TileTransfer{...}`` initializer."""
        return f"""ckt::TileTransfer{{.a_scalar_per_vector = {self.scalar_per_vector[0]}, .b_scalar_per_vector = {self.scalar_per_vector[1]}, .c_scalar_per_vector = {self.scalar_per_vector[2]}}}"""


def _generate_dir():
    """Return the directory that contains this script."""
    return Path(__file__).resolve().parent


def _is_skipped_line(line):
    """Return True for config lines that do not describe an instance.

    Lines containing ``#`` or ``;`` are skipped, as are whitespace-only lines
    (which would otherwise fail while parsing the argument list).
    """
    return not line.strip() or "#" in line or ";" in line


def _warp_layout(
    block_size,
    m_per_block,
    n_per_block,
    m_per_xdl,
    n_per_xdl,
    m_xdl_per_wave,
    n_xdl_per_wave,
):
    """Return ``[m_warp, n_warp, k_warp]`` derived from the block tile sizes."""
    m_warp = m_per_block // (m_per_xdl * m_xdl_per_wave)
    n_warp = n_per_block // (n_per_xdl * n_xdl_per_wave)
    k_warp = block_size // (_WARP_SIZE * m_warp * n_warp)
    return [m_warp, n_warp, k_warp]


def get_dtype(problem_name):
    """Return the C++ element type named by ``problem_name``.

    Raises:
        RuntimeError: if the name contains none of ``fp32``, ``fp16``, ``bf16``.
    """
    if "fp32" in problem_name:
        return "float"
    if "fp16" in problem_name:
        return "ck_tile::half_t"
    if "bf16" in problem_name:
        return "ck_tile::bf16_t"
    raise RuntimeError("Cannot parse data type from problem name: " + problem_name)


def get_k_mfma(dtype, m_per_xdl, n_per_xdl):
    """Return the k value used for a square ``m_per_xdl`` x ``n_per_xdl`` tile.

    Raises:
        RuntimeError: if ``m_per_xdl`` differs from ``n_per_xdl``.
    """
    if m_per_xdl != n_per_xdl:
        raise RuntimeError("Not supported")
    if dtype == "float":
        return 2 if m_per_xdl == 32 else 4
    return 8 if m_per_xdl == 32 else 16


def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector):
    """Return True when every vector size is either 1 or an even number."""
    return all(
        size == 1 or size % 2 == 0
        for size in (a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector)
    )


def generate_calls_inc(instances, problem_name, direction, filter_pattern):
    """Write ``<problem_name>_calls.inc`` with one ``run_alg`` call per instance.

    The file is always created; it is left empty when ``filter_pattern`` does
    not occur in ``problem_name``.
    """
    output_dir = _generate_dir() / "instances" / direction
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f"{problem_name}_calls.inc", "w") as f:
        if filter_pattern not in problem_name:
            return
        for instance in instances:
            instance_name = problem_name + "_" + str(instance.id)
            f.write(f"run_alg(run_{instance_name});\n")


def generate_defs_inc(instances, problem_name, signature, direction, filter_pattern):
    """Write ``<problem_name>.inc`` with one declaration per instance.

    The file is always created; it is left empty when ``filter_pattern`` does
    not occur in ``problem_name``.
    """
    output_dir = _generate_dir() / "instances" / direction
    # Do not rely on generate_calls_inc having created the directory already.
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f"{problem_name}.inc", "w") as f:
        if filter_pattern not in problem_name:
            return
        for instance in instances:
            instance_name = problem_name + "_" + str(instance.id)
            # NOTE(review): `std::tuple` is emitted without template arguments
            # here - confirm this matches the definition in the .cpp.in template.
            f.write(
                f"std::tuple run_{instance_name}(\n"
                f" const ckt::Args<{signature}>& args,\n"
                f" const ckt::Inputs<{signature}>& inputs,\n"
                f" const ckt::Outputs<{signature}>& outputs,\n"
                f" const ck_tile::stream_config& s_conf);\n"
            )


def generate_conv_cpp(
    instances, problem_name, config, direction, signature_name, filter_pattern
):
    """Instantiate the ``.cpp.in`` template once per instance.

    Nothing is read or written when there are no instances or when
    ``filter_pattern`` does not occur in ``problem_name``.
    """
    if not instances or filter_pattern not in problem_name:
        return

    generate_dir = _generate_dir()
    directory_path = generate_dir / "instances" / direction / config
    directory_path.mkdir(parents=True, exist_ok=True)

    # The template does not depend on the instance, so read it only once.
    template_file = "grouped_convolution_tile.cpp.in"
    with open(generate_dir / "instances" / template_file, "r") as f:
        template = f.read()

    for instance in instances:
        instance_name = problem_name + "_" + str(instance.id)
        # Placeholders are substituted in this fixed order.
        substitutions = (
            ("gen_signature", signature_name),
            ("gen_instance_name", instance_name),
            ("gen_specialization", instance.get_specialization()),
            ("gen_thread_block", instance.get_thread_block()),
            ("gen_block_gemm_desc", instance.get_block_gemm_desc()),
            ("gen_block_transfer", instance.get_block_transfer()),
            ("gen_optimizations", instance.get_optimizations()),
        )
        content = template
        for placeholder, value in substitutions:
            content = content.replace(placeholder, value)
        with open(directory_path / f"{instance_name}.cpp", "w") as f:
            f.write(content)


def parse_fwd_instances(instances, problem_name):
    """Parse forward instance strings into ``ConvInstanceTemplateParams``.

    Lines that are blank or contain ``#`` or ``;`` are skipped. The instance
    id is the index of the line in ``instances``.

    Raises:
        RuntimeError: on an unexpected number of parameters, or on a
            DirectLoad instance whose pipeline version is neither v1 nor v4.
    """
    convs = []
    for instance_id, instance in enumerate(instances):
        if _is_skipped_line(instance):
            continue

        instance_args_list = instance[instance.find("<") + 1 : instance.find(">")]
        args = instance_args_list.split(", ")

        block_size = int(args[0])
        m_per_block = int(args[1])
        n_per_block = int(args[2])
        k_per_block = int(args[3])
        spec = args[4]
        m_per_xdl = int(args[5])
        n_per_xdl = int(args[6])
        m_xdl_per_wave = int(args[7])
        n_xdl_per_wave = int(args[8])
        a_scalar_per_vector = int(args[9])
        b_scalar_per_vector = int(args[10])
        c_scalar_per_vector = int(args[11])

        # Only the 15-parameter form carries num_groups_to_merge.
        if len(args) == 15:
            num_groups_to_merge = int(args[14])
        elif len(args) not in (14, 16):
            raise RuntimeError("wrong number of parameters")
        else:
            num_groups_to_merge = 1

        split_image = "Large" in instance
        double_smem_buffer = "BlkGemmPipelineVersion: v4" in instance
        num_wave_groups = 1
        scheduler = (
            args[14] if "BlkGemmPipelineScheduler" in instance else "Intrawave"
        )
        pipeline_version = (
            args[15] if "BlkGemmPipelineVersion" in instance else "v1"
        )

        # Replace pipeline if Direct Load
        if "DirectLoad" in instance:
            if "BlkGemmPipelineVersion: v1" in instance:
                pipeline_version = "ASYNC_V1"
            elif "BlkGemmPipelineVersion: v4" in instance:
                pipeline_version = "ASYNC_V4"
            else:
                raise RuntimeError("not supported pipeline for direct load")
        else:
            # Keep only the trailing version character, e.g. "...: v4" -> "V4".
            pipeline_version = f"""V{pipeline_version[-1:]}"""

        warps = _warp_layout(
            block_size,
            m_per_block,
            n_per_block,
            m_per_xdl,
            n_per_xdl,
            m_xdl_per_wave,
            n_xdl_per_wave,
        )

        dtype = get_dtype(problem_name)
        # TODO: Make it more flexible
        # k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()"
        if dtype == "float":
            if m_per_xdl == 32:
                if "BlkGemmPipelineVersion" not in instance:
                    k_per_xdl = 4
                else:
                    # Increase for universal gemm
                    k_per_xdl = 8
            else:
                k_per_xdl = 8
        elif m_per_xdl == 32:
            k_per_xdl = 16
        else:
            k_per_xdl = 32
        k_per_xdl = min(k_per_xdl, k_per_block)

        convs.append(
            ConvInstanceTemplateParams(
                spec,
                [m_per_block, n_per_block, k_per_block],
                warps,
                [m_per_xdl, n_per_xdl, k_per_xdl],
                double_smem_buffer,
                num_wave_groups,
                pipeline_version,
                scheduler,
                [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
                num_groups_to_merge,
                split_image,
                False,
                instance_id,
            )
        )
    return convs


def parse_instance_string(instance_string):
    """Parse instance string, treating Seq(...) as a single parameter.

    The string is split on commas that are not enclosed in parentheses; each
    parameter is stripped and empty trailing text is dropped.
    """
    params = []
    current_param = ""
    paren_depth = 0

    for char in instance_string:
        if char == "," and paren_depth == 0:
            # Only split on comma if we're not inside parentheses
            params.append(current_param.strip())
            current_param = ""
            continue
        if char == "(":
            paren_depth += 1
        elif char == ")":
            paren_depth -= 1
        current_param += char

    # Add the last parameter
    if current_param.strip():
        params.append(current_param.strip())

    return params


def parse_bwd_weight_instances(instances, problem_name):
    """Parse backward-weight instance strings into ``ConvInstanceTemplateParams``.

    Lines that are blank or contain ``#`` or ``;`` are skipped, as are
    two-stage instances and instances rejected by ``check_vectors``. The
    instance id is the index of the line in ``instances``.

    Raises:
        RuntimeError: on an unexpected number of parameters, an invalid
            pipeline scheduler/version, or differing c_scalar_per_vector
            values in an explicit-GEMM instance.
    """
    convs = []
    for instance_id, instance in enumerate(instances):
        if _is_skipped_line(instance):
            continue

        device_op_name = instance.split("<")[0]
        start = instance.index("<") + 1
        end = instance.rindex(">")
        params_str = instance[start:end]
        args = parse_instance_string(params_str)

        is_v3_instance = "Xdl_CShuffleV3" in instance
        is_two_stage_instance = "TwoStage" in instance
        is_explicit_gemm = "Explicit" in device_op_name

        if is_explicit_gemm:
            # The GEMM parameters are "key: value" pairs; keep only the values.
            # device_op_name is left intact so the error below can report it.
            gemm_params = instance.split("<")[2].split(">")[1].split(",")
            args = [param.split(":")[1].strip() for param in gemm_params]

            spec = "Default"
            block_size = int(args[0])
            mnk_per_block = args[1].split("x")
            m_per_block = int(mnk_per_block[0])
            n_per_block = int(mnk_per_block[1])
            k_per_block = int(mnk_per_block[2])
            wave_tile = args[2].split("x")
            m_per_xdl = int(wave_tile[0])
            n_per_xdl = int(wave_tile[1])
            k1_values = args[3].split("x")
            ak1 = int(k1_values[0])
            bk1 = int(k1_values[1])
            k1 = min(ak1, bk1)
            wave_map = args[4].split("x")
            m_xdl_per_wave = int(wave_map[0])
            n_xdl_per_wave = int(wave_map[1])
            vector_read = args[5].split("x")
            a_scalar_per_vector = int(vector_read[0])
            b_scalar_per_vector = int(vector_read[1])
            c_scalar_per_vector_seq = [
                int(x)
                for x in vector_read[2].strip("Seq").strip("(").strip(")").split(",")
            ]
            if len(set(c_scalar_per_vector_seq)) != 1:
                raise RuntimeError(f"c_scalar_per_vector must be the same across all waves for instance {instance_id} with device op {device_op_name}. Found values: {c_scalar_per_vector_seq}")
            c_scalar_per_vector = c_scalar_per_vector_seq[0]
            num_groups_to_merge = 1

            # Block GEMM pipeline parameters
            blk_gemm_pipeline_scheduler = args[6]
            blk_gemm_pipeline_version = args[7]
        else:
            spec = args[11]
            block_size = int(args[12])
            m_per_block = int(args[13])
            n_per_block = int(args[14])
            k1 = int(args[16])
            m_per_xdl = int(args[17])
            n_per_xdl = int(args[18])
            m_xdl_per_wave = int(args[19])
            n_xdl_per_wave = int(args[20])
            a_scalar_per_vector = int(args[25])
            b_scalar_per_vector = int(args[32])
            c_scalar_per_vector = int(args[38])

            if is_v3_instance or is_two_stage_instance:
                k_per_block = int(args[15])
            else:
                # Here parameter 15 is k0 and has to be scaled by k1.
                k0_per_block = int(args[15])
                k_per_block = k0_per_block * k1

            if is_v3_instance:
                if len(args) != 45:
                    raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}")
                num_groups_to_merge = int(args[44])

                # Block GEMM pipeline parameters
                blk_gemm_pipeline_scheduler = args[39]
                blk_gemm_pipeline_version = args[40]
            elif is_two_stage_instance:
                print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.")
                continue
            else:
                # Regular V1 XDL CShuffle instance
                if len(args) != 43:
                    raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}")
                num_groups_to_merge = 1

                # Block GEMM pipeline parameters
                blk_gemm_pipeline_scheduler = "Intrawave"
                blk_gemm_pipeline_version = "v1"

        # Common part to all solvers.

        # Sanity check for Block GEMM pipeline parameters
        # Scheduler must be either Intrawave or Interwave.
        # Version must be from v1 to v5
        if blk_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
            raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_scheduler} in instance: {instance}")
        if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
            raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")

        split_image = "Large" in instance
        double_smem_buffer = blk_gemm_pipeline_version == "v4"
        num_wave_groups = 1
        scheduler = blk_gemm_pipeline_scheduler
        pipeline_version = blk_gemm_pipeline_version.upper()
        # Old CK pipeline version V5 maps to V6 for CK Tile
        if pipeline_version == "V5":
            pipeline_version = "V6"

        warps = _warp_layout(
            block_size,
            m_per_block,
            n_per_block,
            m_per_xdl,
            n_per_xdl,
            m_xdl_per_wave,
            n_xdl_per_wave,
        )

        dtype = get_dtype(problem_name)
        k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl))

        if not check_vectors(
            a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector
        ):
            print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
            continue

        convs.append(
            ConvInstanceTemplateParams(
                spec,
                [m_per_block, n_per_block, k_per_block],
                warps,
                [m_per_xdl, n_per_xdl, k_per_xdl],
                double_smem_buffer,
                num_wave_groups,
                pipeline_version,
                scheduler,
                [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
                num_groups_to_merge,
                split_image,
                is_explicit_gemm,
                instance_id,
            )
        )
    return convs


def parse_bwd_data_instances(instances, problem_name):
    """Return an empty list; backward-data parsing is not implemented yet."""
    convs = []
    print("Parsing backward data instances is not supported yet, skipping all instances.")
    # TODO: Implement parsing logic for backward data instances.
    return convs


def _generate_outputs(
    parsed_instances, problem_name, config, direction, signature_name, filter_pattern
):
    """Write both ``.inc`` files and the ``.cpp`` files for parsed instances."""
    generate_calls_inc(parsed_instances, problem_name, direction, filter_pattern)
    generate_defs_inc(
        parsed_instances, problem_name, signature_name, direction, filter_pattern
    )
    generate_conv_cpp(
        parsed_instances, problem_name, config, direction, signature_name, filter_pattern
    )


def generate_instances_fwd(instances, problem_name, config, filter_pattern):
    """Parse forward instance strings and generate all forward outputs."""
    _generate_outputs(
        parse_fwd_instances(instances, problem_name),
        problem_name,
        config,
        "forward",
        f"SIGNATURE_{config.upper()}_FWD",
        filter_pattern,
    )


def generate_instances_bwd_weight(instances, problem_name, config, filter_pattern):
    """Parse backward-weight instance strings and generate all their outputs."""
    _generate_outputs(
        parse_bwd_weight_instances(instances, problem_name),
        problem_name,
        config,
        "backward_weight",
        f"SIGNATURE_{config.upper()}_BWD_WEIGHT",
        filter_pattern,
    )


def generate_instances_bwd_data(instances, problem_name, config, filter_pattern):
    """Parse backward-data instance strings and generate all their outputs."""
    _generate_outputs(
        parse_bwd_data_instances(instances, problem_name),
        problem_name,
        config,
        "backward_data",
        f"SIGNATURE_{config.upper()}_BWD_DATA",
        filter_pattern,
    )


def process_direction(configs, direction, generate_func, configs_prefix, filter_pattern):
    """Helper function to process a single direction.

    For each config, reads
    ``configs/<direction>/<configs_prefix>/<config>.conf`` line by line and
    passes the lines to ``generate_func``.

    Raises:
        RuntimeError: if ``direction`` is not a known direction.
    """
    generate_dir = _generate_dir()
    for config in configs:
        config_path = (
            generate_dir / "configs" / direction / configs_prefix / f"{config}.conf"
        )
        with open(config_path, "r") as file:
            instances = file.readlines()

        # Determine problem name based on direction
        if direction == "forward":
            problem_name = f"grouped_convolution_forward_tile_{config}"
        elif direction == "backward_weight":
            problem_name = f"grouped_convolution_backward_weight_tile_{config}"
        elif direction == "backward_data":
            problem_name = f"grouped_convolution_backward_data_tile_{config}"
        else:
            raise RuntimeError(f"Unknown direction: {direction}")

        generate_func(instances, problem_name, config, filter_pattern)


def main():
    """Parse the command line and generate instances for the chosen direction(s)."""
    fwd_configs = [
        "nhwgc_fp32",
        "nhwgc_fp16",
        "nhwgc_bf16",
        "ndhwgc_fp32",
        "ndhwgc_fp16",
        "ndhwgc_bf16",
    ]
    # FP32 doesn't work for bwd weight currently
    bwd_weight_configs = [
        "nhwgc_fp32",
        "nhwgc_fp16",
        "nhwgc_bf16",
        "ndhwgc_fp32",
        "ndhwgc_fp16",
        "ndhwgc_bf16",
    ]
    bwd_data_configs = [
        "nhwgc_fp32",
        "nhwgc_fp16",
        "nhwgc_bf16",
        "ndhwgc_fp32",
        "ndhwgc_fp16",
        "ndhwgc_bf16",
    ]

    parser = argparse.ArgumentParser(
        description="Generate grouped conv CK Tile instances."
    )
    parser.add_argument(
        "--filter_pattern",
        type=str,
        default="convolution",
        help="Filter pattern for configs.",
    )
    parser.add_argument(
        "--mode",
        choices=["compilation", "tests", "profiler"],
        type=str,
        default="profiler",
        help="Generator modes. compilation - empty instance list, tests - limited instance list, profiler - generate all instances",
    )
    parser.add_argument(
        "--direction",
        choices=["forward", "backward_weight", "backward_data", "all"],
        type=str,
        default="all",
        help="Convolution direction for which to generate instances.",
    )
    args = parser.parse_args()

    # apply empty filter
    if args.mode == "compilation":
        args.filter_pattern = "empty"
        configs_prefix = "profiler"
    elif args.mode == "tests":
        configs_prefix = "tests"
    elif args.mode == "profiler":
        configs_prefix = "profiler"
    else:
        raise RuntimeError("wrong mode")

    # Insertion order is the processing order used for "all".
    directions = {
        "forward": (fwd_configs, generate_instances_fwd),
        "backward_weight": (bwd_weight_configs, generate_instances_bwd_weight),
        "backward_data": (bwd_data_configs, generate_instances_bwd_data),
    }
    selected = list(directions) if args.direction == "all" else [args.direction]
    for direction in selected:
        configs, generate_func = directions[direction]
        process_direction(
            configs, direction, generate_func, configs_prefix, args.filter_pattern
        )


if __name__ == "__main__":
    main()