mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
modify python wrapper for addmm (#1441)
[ROCm/composable_kernel commit: 886d14ccb2]
This commit is contained in:
@@ -10,10 +10,12 @@ class CKGemmOperation:
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
ds_layouts: Tuple[str] # addmm specific
|
||||
c_layout: str
|
||||
|
||||
a_element_dtype: str
|
||||
b_element_dtype: str
|
||||
ds_element_dtypes: Tuple[str] # addmm specific
|
||||
c_element_dtype: str
|
||||
|
||||
acc_dtype: str
|
||||
@@ -64,16 +66,15 @@ class CKGemmOperation:
|
||||
Tuple[int, int, int, int]
|
||||
)
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
|
||||
|
||||
block_gemm_pipeline_scheduler: str
|
||||
block_gemm_pipeline_version: Optional[str]
|
||||
block_gemm_pipeline_version: str
|
||||
|
||||
a_compute_dtype: Optional[str]
|
||||
b_compute_dtype: Optional[str]
|
||||
a_compute_dtype: Optional[str] = None
|
||||
b_compute_dtype: Optional[str] = None
|
||||
|
||||
def name(self):
|
||||
# cpp alias for template instance
|
||||
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
|
||||
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
|
||||
|
||||
def key_name(self):
|
||||
# TBD; must be unique per instance. Intended to use as dict key
|
||||
|
||||
Reference in New Issue
Block a user