diff --git a/python/ck4inductor/batched_universal_gemm/gen_instances.py b/python/ck4inductor/batched_universal_gemm/gen_instances.py
new file mode 100644
index 0000000000..8879fb93db
--- /dev/null
+++ b/python/ck4inductor/batched_universal_gemm/gen_instances.py
@@ -0,0 +1,174 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+import logging
+import os
+import subprocess
+from dataclasses import replace
+from functools import lru_cache
+from typing import List
+
+from ..util import library_path
+
+from .op import CKBatchedGemmOperation
+
+log = logging.getLogger(__name__)
+
+
+def _ck_library_dir():
+    """
+    Return the path of the batched universal gemm instance sources inside the
+    CK library, or None (after logging an error) if that path does not exist.
+    """
+    gemm_instances_path = os.path.join(
+        library_path(),
+        "src",
+        "tensor_operation_instance",
+        "gpu",
+        "gemm_universal_batched",
+    )
+    if not os.path.exists(gemm_instances_path):
+        log.error("CK library path %s does not exist", gemm_instances_path)
+        return None
+    return gemm_instances_path
+
+
+def parse_instances(str_instances: List[str]) -> List[CKBatchedGemmOperation]:
+    """
+    Parse the lines containing Batched Universal Gemm template instances into
+    `CKBatchedGemmOperation` instances.
+
+    Everything after the last `DeviceBatchedGemmMultiD_Xdl_CShuffle_V3` on a line
+    is read as a comma-separated template argument list: `S<...>` sequences become
+    tuples of ints, integer literals become ints, anything else stays a string.
+    Blank lines are skipped.
+    """
+
+    def maybe_int(s):
+        try:
+            return int(s)
+        except ValueError:
+            return s
+
+    op_instances = []
+    for line in str_instances:
+        if not line.strip():
+            # blank input, e.g. "".split("\n") == [""]; nothing to parse
+            continue
+        s_template_args = line.split("DeviceBatchedGemmMultiD_Xdl_CShuffle_V3")[
+            -1
+        ].strip("<>, ")
+        template_args = []
+        i_current = 0
+        while i_current < len(s_template_args):
+            if s_template_args[i_current] == " ":
+                # skip whitespace
+                i_current += 1
+                continue
+            elif s_template_args[i_current : i_current + 2] == "S<":
+                # parse template S
+                i_next = s_template_args.find(">", i_current)
+                template_args.append(
+                    tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
+                )
+                i_current = i_next + 2
+            else:
+                # all string attributes must be either type aliases or global constants in C++
+                i_next = s_template_args.find(",", i_current)
+                template_args.append(
+                    maybe_int(
+                        s_template_args[i_current : i_next if i_next != -1 else None]
+                    )
+                )
+                if i_next != -1:
+                    i_current = i_next + 1
+            if i_next == -1:
+                break
+
+        # ds layout and dtype are parsed as placeholder; reset value
+        template_args[2] = tuple()  # ds layout
+        template_args[6] = tuple()  # ds dtype
+
+        new_instance = CKBatchedGemmOperation(
+            *template_args,  # type: ignore[arg-type]
+        )
+
+        op_instances.append(new_instance)
+    return op_instances
+
+
+@lru_cache(None)
+def gen_ops_library() -> List[CKBatchedGemmOperation]:
+    """
+    Parse the Batched Universal Gemm instances defined in the composable kernel
+    library folder, expanding the `BlkGemmPipeSched` and `GemmSpec` placeholders
+    into one instance per concrete value. The result is cached.
+
+    Returns an empty list when the library folder does not exist.
+    """
+    ck_library_dir = _ck_library_dir()
+    if not ck_library_dir:
+        return []
+
+    grep_result = subprocess.run(
+        [
+            "grep",
+            "-inR",
+            "DeviceBatchedGemmMultiD_Xdl_CShuffle_V3",
+            ck_library_dir,
+        ],
+        capture_output=True,
+        text=True,
+    )
+    if grep_result.returncode > 1:
+        # exit status 1 only means "no match"; anything above is a real error
+        log.warning(
+            "grep exited with %d: %s",
+            grep_result.returncode,
+            grep_result.stderr.strip(),
+        )
+
+    # splitlines() turns empty output into [] instead of [""]
+    op_instances = parse_instances(grep_result.stdout.splitlines())
+
+    log.debug("ck instances from library: %d", len(op_instances))
+
+    schedulers = [
+        "BlockGemmPipelineScheduler::Intrawave",
+        "BlockGemmPipelineScheduler::Interwave",
+    ]
+    gemm_specs = [
+        "GemmSpecialization::Default",
+        "GemmSpecialization::MPadding",
+        "GemmSpecialization::NPadding",
+        "GemmSpecialization::KPadding",
+        "GemmSpecialization::MNPadding",
+        "GemmSpecialization::MKPadding",
+        "GemmSpecialization::NKPadding",
+        "GemmSpecialization::MNKPadding",
+    ]
+
+    # substitute templated args by looping through their domains
+    substitute_instances = []
+    for instance in op_instances:
+        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
+        sub_spec = instance.gemm_specialization == "GemmSpec"
+        schedulers_range = (
+            schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
+        )
+        spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
+        for scheduler in schedulers_range:
+            for spec in spec_range:
+                substitute_instances.append(
+                    replace(
+                        instance,
+                        block_gemm_pipeline_scheduler=scheduler,
+                        gemm_specialization=spec,
+                    )
+                )
+
+    return substitute_instances
+
+
+if __name__ == "__main__":
+    print(gen_ops_library())
diff --git a/python/ck4inductor/batched_universal_gemm/op.py b/python/ck4inductor/batched_universal_gemm/op.py
new file mode 100644
index 0000000000..96978ac8d2
--- /dev/null
+++ b/python/ck4inductor/batched_universal_gemm/op.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+from dataclasses import asdict, dataclass
+from typing import Optional, Tuple
+
+
+@dataclass
+class CKBatchedGemmOperation:
+    """
+    A python dataclass storing the template parameters of a CK Batched Universal
+    Gemm template instance. `parse_instances` fills the fields positionally, so
+    their order must follow the C++ template argument order.
+    """
+
+    a_layout: str
+    b_layout: str
+    ds_layouts: Tuple[str, ...]  # addmm specific
+    c_layout: str
+
+    a_element_dtype: str
+    b_element_dtype: str
+    ds_element_dtypes: Tuple[str, ...]  # addmm specific
+    c_element_dtype: str
+
+    acc_dtype: str
+    c_shuffle_dtype: str
+
+    a_elementwise_op: str
+    b_elementwise_op: str
+    c_elementwise_op: str
+
+    gemm_specialization: str
+
+    block_size: int
+
+    m_per_block: int
+    n_per_block: int
+    k_per_block: int
+
+    a_k1: int
+    b_k1: int
+
+    m_per_xdl: int
+    n_per_xdl: int
+
+    m_xdl_per_wave: int
+    n_xdl_per_wave: int
+
+    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
+    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
+    a_block_transfer_src_access_order: Tuple[int, int, int]
+    a_block_transfer_src_vector_dim: int
+    a_block_transfer_src_scalar_per_vector: int
+    a_block_transfer_dst_scalar_per_vector_ak1: int
+    a_block_lds_extra_m: bool
+
+    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
+    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
+    b_block_transfer_src_access_order: Tuple[int, int, int]
+
+    b_block_transfer_src_vector_dim: int
+    b_block_transfer_src_scalar_per_vector: int
+    b_block_transfer_dst_scalar_per_vector_bk1: int
+    b_block_lds_extra_n: bool
+
+    c_shuffle_m_xdl_per_wave_per_shuffle: int
+    c_shuffle_n_xdl_per_wave_per_shuffle: int
+
+    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
+        Tuple[int, int, int, int]
+    )
+    c_shuffle_block_transfer_scalar_per_vector_n_per_block: Tuple[int, ...]
+    block_gemm_pipeline_scheduler: str
+    block_gemm_pipeline_version: str
+
+    a_compute_dtype: Optional[str] = None
+    b_compute_dtype: Optional[str] = None
+
+    def name(self):
+        # cpp alias for template instance
+        return f"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_{self.key_name()}"
+
+    def key_name(self):
+        # TBD; must be unique per instance. Intended to use as dict key
+        # Each field contributes "K<name without underscores>V<value>"; tuples are
+        # joined with "x" and ":" is dropped from any other value.
+        return "_".join(
+            [
+                "K"
+                + field_name.replace("_", "").lower()
+                + "V"
+                + (
+                    "x".join(map(str, field_value))
+                    if isinstance(field_value, tuple)
+                    else str(field_value).replace(":", "")
+                )
+                for field_name, field_value in self.dict_items()
+            ]
+        )
+
+    def dict_items(self):
+        # (field name, value) pairs of this instance
+        return asdict(self).items()
diff --git a/python/ck4inductor/grouped_conv_fwd/gen_instances.py b/python/ck4inductor/grouped_conv_fwd/gen_instances.py
index ffbea6bdc7..feca20a3b8 100644
--- a/python/ck4inductor/grouped_conv_fwd/gen_instances.py
+++ b/python/ck4inductor/grouped_conv_fwd/gen_instances.py
@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
     # substitute templated args by looping through their domains
     substitute_instances = []
     for instance in op_instances:
-        sub_scheduler = (
-            instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
-        )
+        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
         sub_spec = instance.conv_forward_specialization == "ConvSpec"
         schedulers_range = (
             schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]