mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
modify python wrapper for addmm (#1441)
This commit is contained in:
@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
|
||||
i_current = i_next + 1
|
||||
if i_next == -1:
|
||||
break
|
||||
# pad with `None`s for the fields which are not defined in the instance
|
||||
|
||||
template_args.insert(2, tuple()) # ds layout
|
||||
template_args.insert(6, tuple()) # ds dtype
|
||||
|
||||
new_instance = CKGemmOperation(
|
||||
*template_args, # type: ignore[arg-type]
|
||||
*((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
|
||||
)
|
||||
# the last 2 template parameters are optional
|
||||
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
|
||||
if new_instance.a_compute_dtype is None:
|
||||
new_instance.a_compute_dtype = new_instance.c_element_dtype
|
||||
if new_instance.b_compute_dtype is None:
|
||||
new_instance.b_compute_dtype = new_instance.c_element_dtype
|
||||
|
||||
op_instances.append(new_instance)
|
||||
return op_instances
|
||||
@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]:
|
||||
a_layout="Row",
|
||||
b_layout="Col",
|
||||
c_layout="Row",
|
||||
ds_element_dtypes=tuple(),
|
||||
ds_layouts=tuple(),
|
||||
a_element_dtype="F16",
|
||||
b_element_dtype="F16",
|
||||
c_element_dtype="F16",
|
||||
|
||||
@@ -10,10 +10,12 @@ class CKGemmOperation:
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
ds_layouts: Tuple[str] # addmm specific
|
||||
c_layout: str
|
||||
|
||||
a_element_dtype: str
|
||||
b_element_dtype: str
|
||||
ds_element_dtypes: Tuple[str] # addmm specific
|
||||
c_element_dtype: str
|
||||
|
||||
acc_dtype: str
|
||||
@@ -64,16 +66,15 @@ class CKGemmOperation:
|
||||
Tuple[int, int, int, int]
|
||||
)
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
|
||||
|
||||
block_gemm_pipeline_scheduler: str
|
||||
block_gemm_pipeline_version: Optional[str]
|
||||
block_gemm_pipeline_version: str
|
||||
|
||||
a_compute_dtype: Optional[str]
|
||||
b_compute_dtype: Optional[str]
|
||||
a_compute_dtype: Optional[str] = None
|
||||
b_compute_dtype: Optional[str] = None
|
||||
|
||||
def name(self):
|
||||
# cpp alias for template instance
|
||||
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
|
||||
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
|
||||
|
||||
def key_name(self):
|
||||
# TBD; must be unique per instance. Intended to use as dict key
|
||||
|
||||
Reference in New Issue
Block a user