mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-22 16:17:37 +00:00
* chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory * chore(copyright): update copyright header for test_data directory * chore(copyright): update copyright header for python directory
100 lines
2.8 KiB
Python
100 lines
2.8 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
@dataclass
class CKGemmOperation:
    """
    A python dataclass storing the template parameters of a CK Universal Gemm
    template instance.

    Field declaration order matters: ``key_name`` and ``dict_items`` walk the
    fields in declaration order (via ``dataclasses.asdict``), so reordering
    fields changes the generated instance names.
    """

    # Tensor layouts.
    a_layout: str
    b_layout: str
    # addmm specific. Variable-length: ``Tuple[str]`` would mean exactly one
    # element. Presumably one entry per extra D tensor (TODO confirm).
    ds_layouts: Tuple[str, ...]
    c_layout: str

    # Element data types.
    a_element_dtype: str
    b_element_dtype: str
    ds_element_dtypes: Tuple[str, ...]  # addmm specific, same length as ds_layouts (presumably)
    c_element_dtype: str

    # Accumulator and C-shuffle data types.
    acc_dtype: str
    c_shuffle_dtype: str

    # Elementwise operations applied to A, B and C.
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str

    gemm_specialization: str

    block_size: int

    # Block tile sizes.
    m_per_block: int
    n_per_block: int
    k_per_block: int

    a_k1: int
    b_k1: int

    # XDL tile sizes and XDL repeats per wave.
    m_per_xdl: int
    n_per_xdl: int

    m_xdl_per_wave: int
    n_xdl_per_wave: int

    # A-side block transfer parameters.
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool

    # B-side block transfer parameters.
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]

    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool

    # C-shuffle epilogue parameters.
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int

    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
        Tuple[int, int, int, int]
    )
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int

    # Block GEMM pipeline selection.
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: str

    # Optional compute types; left as None when not specified.
    a_compute_dtype: Optional[str] = None
    b_compute_dtype: Optional[str] = None

    def name(self):
        """Return the cpp alias for this template instance.

        The alias is ``"ck_devicegemm_multid_xdl_shuffle_v3_"`` followed by
        the string returned by ``key_name()``.
        """
        return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"

    def key_name(self):
        """Return a string encoding every field of this instance.

        Each field contributes one token ``K<name>V<value>``:

        * ``<name>`` is the field name with underscores removed, lowercased.
        * ``<value>`` is built by ``_format_key_value``: tuple elements are
          joined by ``"x"``, and any other value is ``str(value)`` with
          every ``":"`` removed.

        Tokens are joined with ``"_"`` in field declaration order.

        NOTE (TBD upstream): this key is intended for use as a dict key, so it
        must be unique per instance. Because ``":"`` is stripped, values that
        differ only by colons would produce the same key.
        """
        tokens = (
            "K"
            + field_name.replace("_", "").lower()
            + "V"
            + self._format_key_value(field_value)
            for field_name, field_value in self.dict_items()
        )
        return "_".join(tokens)

    @staticmethod
    def _format_key_value(field_value):
        """Render a single field value as a token fragment for ``key_name``."""
        if isinstance(field_value, tuple):
            # e.g. (8, 32, 1) -> "8x32x1"; an empty tuple renders as "".
            return "x".join(map(str, field_value))
        # Strip C++ scope separators, e.g. "ck::Foo::Bar" -> "ckFooBar".
        return str(field_value).replace(":", "")

    def dict_items(self):
        """Return ``(field_name, value)`` pairs in field declaration order.

        The pairs come from ``dataclasses.asdict(self)``. Tuple-valued fields
        remain tuples.
        """
        return asdict(self).items()
|