diff --git a/python/ck4inductor/grouped_conv_fwd/gen_instances.py b/python/ck4inductor/grouped_conv_fwd/gen_instances.py new file mode 100644 index 0000000000..ffbea6bdc7 --- /dev/null +++ b/python/ck4inductor/grouped_conv_fwd/gen_instances.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +import logging +import os +import subprocess +from dataclasses import replace +from functools import lru_cache +from typing import List + +from ..util import library_path + +from .op import CKGroupedConvFwdOp + +log = logging.getLogger(__name__) + + +def _ck_conv_instances_path(): + conv_instances_path = os.path.join( # noqa: F821 + library_path(), + "include", + "ck", + "library", + "tensor_operation_instance", + "gpu", + "grouped_conv_fwd", + ) + if not os.path.exists(conv_instances_path): + log.error( + "CK library conv instances path %s does not exist", conv_instances_path + ) + return None + return conv_instances_path + + +def parse_instances(str_instances: List[str]) -> List[CKGroupedConvFwdOp]: + """ + Parse the lines containing Grouped Convolution Forward template instances + into `CKGroupedConvFwdOp` instances + """ + + def maybe_int(s): + try: + return int(s) + except ValueError: + return s + + op_instances = [] + # TODO: maybe use libclang for parsing C++ code in the future + # to avoid this hacky parsing logic below ? :) - copilot + for line in str_instances: + s_template_args = line.split("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")[ + -1 + ].strip("<>, ") + template_args = [] + i_current = 0 + while i_current < len(s_template_args): + if s_template_args[i_current] == " ": + # skip whitespace + i_current += 1 + continue + elif s_template_args[i_current : i_current + 2] == "S<": + # parse template S + i_next = s_template_args.find(">", i_current) + template_args.append( + tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) + ) + i_current = i_next + 2 + else: + # all string attributes must be either type aliases or global constants in C++ + i_next = s_template_args.find(",", i_current) + template_args.append( + maybe_int( + s_template_args[i_current : i_next if i_next != -1 else None] + ) + ) + if i_next != -1: + i_current = i_next + 1 + if i_next == -1: + break + + template_args[0] = -1 # n_dim_spatial + template_args[3] = tuple() # ds_layout + template_args[9] = tuple() # ds_element_dtype + + new_instance = CKGroupedConvFwdOp( + *template_args, # type: ignore[arg-type] + ) + + op_instances.append(new_instance) + return op_instances + + +@lru_cache(None) +def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]: + """ + Parse the Grouped Convolution Forward instances + defined in the Composable Kernel library folder. + """ + ck_library_dir = _ck_conv_instances_path() + if not ck_library_dir: + return [] + + grep_result = subprocess.run( + [ + "grep", + "-inR", + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + ck_library_dir, + ], + capture_output=True, + text=True, + ) + + op_instances = parse_instances(grep_result.stdout.strip().split("\n")) + + log.debug("ck instances from library: %d", len(op_instances)) + + schedulers = [ + "BlockGemmPipelineScheduler::Intrawave", + "BlockGemmPipelineScheduler::Interwave", + ] + conv_specs = [ + "ConvolutionForwardSpecialization::Default", + "ConvolutionForwardSpecialization::Filter1x1Pad0", + "ConvolutionForwardSpecialization::Filter1x1Stride1Pad0", + "ConvolutionForwardSpecialization::OddC", + ] + + # substitute templated args by looping through their domains + substitute_instances = [] + for instance in op_instances: + sub_scheduler = ( + instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" + ) + sub_spec = instance.conv_forward_specialization == "ConvSpec" + schedulers_range = ( + schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] + ) + spec_range = conv_specs if sub_spec else [instance.conv_forward_specialization] + for scheduler in schedulers_range: + for spec in spec_range: + for channels_last in [True, False]: + if channels_last: + a_layout = "NHWGC" + e_layout = "NHWGK" + else: + a_layout = "NGCHW" + e_layout = "NGKHW" + substitute_instances.append( + replace( + instance, + block_gemm_pipeline_scheduler=scheduler, + conv_forward_specialization=spec, + gemm_specialization="GemmSpecialization::MNKPadding", + n_dim_spatial=2, + a_layout=a_layout, + b_layout="GKYXC", + e_layout=e_layout, + ) + ) + + return substitute_instances + + +if __name__ == "__main__": + print(gen_conv_ops_library()) diff --git a/python/ck4inductor/grouped_conv_fwd/op.py b/python/ck4inductor/grouped_conv_fwd/op.py new file mode 100644 index 0000000000..25d45e8ffa --- /dev/null +++ b/python/ck4inductor/grouped_conv_fwd/op.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + + +@dataclass +class CKGroupedConvFwdOp: + n_dim_spatial: int + a_layout: str + b_layout: str + ds_layout: Tuple[str] + e_layout: str + a_element_dtype: str + b_element_dtype: str + acc_dtype: str + c_shuffle_dtype: str + ds_element_dtype: Tuple[str] + e_element_dtype: str + a_elementwise_op: str + b_elementwise_op: str + cde_elementwise_op: str + conv_forward_specialization: str + gemm_specialization: str + + block_size: int + m_per_block: int + n_per_block: int + k_per_block: int + ak1: int + bk1: int + m_per_xdl: int + n_per_xdl: int + m_xdl_per_wave: int + n_xdl_per_wave: int + a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] + a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + a_block_transfer_src_access_order: Tuple[int, int, int] + a_block_transfer_src_vector_dim: int + a_block_transfer_src_scalar_per_vector: int + a_block_transfer_dst_scalar_per_vector_ak1: int + a_block_lds_extra_m: bool + + b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] + b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + b_block_transfer_src_access_order: Tuple[int, int, int] + + b_block_transfer_src_vector_dim: int + b_block_transfer_src_scalar_per_vector: int + b_block_transfer_dst_scalar_per_vector_bk1: int + b_block_lds_extra_n: bool + + c_shuffle_m_xdl_per_wave_per_shuffle: int + c_shuffle_n_xdl_per_wave_per_shuffle: int + cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[ # noqa + int, + int, + int, + int, + ] + cde_block_transfer_scalar_per_vector_n_per_block: int + block_gemm_pipeline_scheduler: str + block_gemm_pipeline_version: str + + a_compute_dtype: Optional[str] = None + b_compute_dtype: Optional[str] = None + + def name(self): + # cpp alias for template instance + return ( + f"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_" + f"{self.key_name()}" + ) + + def key_name(self): + # TBD; must be unique per instance. Intended to use as dict key + return "_".join( + [ + "K" + + field_name.replace("_", "").lower() + + "V" + + ( + "x".join(map(str, iter(field_value))) + if isinstance(field_value, tuple) + else str(field_value).replace(":", "") + ) + for field_name, field_value in self.dict_items() + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py index 5594b86817..24bab54776 100644 --- a/python/ck4inductor/universal_gemm/gen_instances.py +++ b/python/ck4inductor/universal_gemm/gen_instances.py @@ -1,7 +1,10 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + import logging import os import subprocess -from dataclasses import fields, replace +from dataclasses import replace from functools import lru_cache, partial from typing import List diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py index a8bb725005..946aaa7afb 100644 --- a/python/ck4inductor/universal_gemm/op.py +++ b/python/ck4inductor/universal_gemm/op.py @@ -1,3 +1,6 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + from dataclasses import asdict, dataclass from typing import Optional, Tuple diff --git a/python/ck4inductor/util.py b/python/ck4inductor/util.py index 79d6be00f3..4d7e8bd87d 100644 --- a/python/ck4inductor/util.py +++ b/python/ck4inductor/util.py @@ -1,7 +1,10 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + import functools import os @functools.lru_cache(None) def library_path(): - return os.path.join(os.path.dirname(__file__), 'library') + return os.path.join(os.path.dirname(__file__), "library")