From f4b5582b2ae35906ea992200468f3c140d82c815 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:09:27 -0700 Subject: [PATCH] modify python wrapper for addmm (#1441) [ROCm/composable_kernel commit: 886d14ccb221033ab4f83c54a03bb04d94af594f] --- python/ck4inductor/universal_gemm/gen_instances.py | 14 ++++++-------- python/ck4inductor/universal_gemm/op.py | 11 ++++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py index 8b6d6b73b2..5594b86817 100644 --- a/python/ck4inductor/universal_gemm/gen_instances.py +++ b/python/ck4inductor/universal_gemm/gen_instances.py @@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: i_current = i_next + 1 if i_next == -1: break - # pad with `None`s for the fields which are not defined in the instance + + template_args.insert(2, tuple()) # ds layout + template_args.insert(6, tuple()) # ds dtype + new_instance = CKGemmOperation( *template_args, # type: ignore[arg-type] - *((None,) * (len(fields(CKGemmOperation)) - len(template_args))), ) - # the last 2 template parameters are optional - # if they are absent, substitute them with default values from Universal Gemm C++ template declaration - if new_instance.a_compute_dtype is None: - new_instance.a_compute_dtype = new_instance.c_element_dtype - if new_instance.b_compute_dtype is None: - new_instance.b_compute_dtype = new_instance.c_element_dtype op_instances.append(new_instance) return op_instances @@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]: a_layout="Row", b_layout="Col", c_layout="Row", + ds_element_dtypes=tuple(), + ds_layouts=tuple(), a_element_dtype="F16", b_element_dtype="F16", c_element_dtype="F16", diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py index ab541c5fb9..a8bb725005 100644 --- a/python/ck4inductor/universal_gemm/op.py +++ b/python/ck4inductor/universal_gemm/op.py @@ -10,10 +10,12 @@ class CKGemmOperation: a_layout: str b_layout: str + ds_layouts: Tuple[str] # addmm specific c_layout: str a_element_dtype: str b_element_dtype: str + ds_element_dtypes: Tuple[str] # addmm specific c_element_dtype: str acc_dtype: str @@ -64,16 +66,15 @@ class CKGemmOperation: Tuple[int, int, int, int] ) c_shuffle_block_transfer_scalar_per_vector_n_per_block: int - block_gemm_pipeline_scheduler: str - block_gemm_pipeline_version: Optional[str] + block_gemm_pipeline_version: str - a_compute_dtype: Optional[str] - b_compute_dtype: Optional[str] + a_compute_dtype: Optional[str] = None + b_compute_dtype: Optional[str] = None def name(self): # cpp alias for template instance - return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}" + return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}" def key_name(self): # TBD; must be unique per instance. Intended to use as dict key