modify python wrapper for addmm (#1441)

[ROCm/composable_kernel commit: 886d14ccb2]
This commit is contained in:
Max Podkorytov
2024-08-06 15:09:27 -07:00
committed by GitHub
parent 62bd135b8c
commit 3aeb97f3ec
2 changed files with 12 additions and 13 deletions

View File

@@ -10,10 +10,12 @@ class CKGemmOperation:
a_layout: str
b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str
a_element_dtype: str
b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str
acc_dtype: str
@@ -64,16 +66,15 @@ class CKGemmOperation:
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: Optional[str]
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str]
b_compute_dtype: Optional[str]
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key