Files
Aviral Goel 3aa883b9ff chore(copyright): update copyright header for python directory (#3200)
* chore(copyright): update copyright header for tile_engine directory

* chore(copyright): update copyright header for script directory

* chore(copyright): update copyright header for test_data directory

* chore(copyright): update copyright header for python directory
2025-11-14 08:21:36 -08:00

100 lines
2.8 KiB
Python

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKGemmOperation:
    """
    A python dataclass storing the template parameters of a CK Universal Gemm template instance.

    Fields are declared in the order they are serialized by ``key_name`` (and
    therefore ``name``); tuple-valued fields are rendered as ``x``-joined lists.
    """

    a_layout: str
    b_layout: str
    ds_layouts: Tuple[str, ...]  # addmm specific; one entry per D tensor (may be empty)
    c_layout: str
    a_element_dtype: str
    b_element_dtype: str
    ds_element_dtypes: Tuple[str, ...]  # addmm specific; one entry per D tensor (may be empty)
    c_element_dtype: str
    acc_dtype: str
    c_shuffle_dtype: str
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str
    gemm_specialization: str
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    a_k1: int
    b_k1: int
    m_per_xdl: int
    n_per_xdl: int
    m_xdl_per_wave: int
    n_xdl_per_wave: int
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]
    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int
    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
        Tuple[int, int, int, int]
    )
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: str
    a_compute_dtype: Optional[str] = None
    b_compute_dtype: Optional[str] = None

    def name(self) -> str:
        """Return the C++ alias used for this template instance."""
        return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"

    def key_name(self) -> str:
        """
        Build a string key from every field, in declaration order.

        Each field is rendered as ``K<field name without underscores>V<value>``
        and the pieces are joined with ``_``. Tuple values are joined with
        ``x``; ``:`` is stripped from every rendered value (including tuple
        elements) so the key can be embedded in the C++ alias built by
        ``name``. Intended to be usable as a dict key.

        NOTE(review): values are not escaped beyond ``:`` removal, so
        uniqueness relies on distinct instances not colliding after that
        stripping.
        """
        return "_".join(
            f"K{field_name.replace('_', '').lower()}V{self._format_key_value(field_value)}"
            for field_name, field_value in self.dict_items()
        )

    @staticmethod
    def _format_key_value(value) -> str:
        """Render one field value for ``key_name`` (tuples ``x``-joined, ``:`` removed)."""
        if isinstance(value, tuple):
            # Strip colons per element too; otherwise e.g. "ck::half_t" in a
            # tuple field would leave "::" in the generated C++ identifier.
            return "x".join(str(item).replace(":", "") for item in value)
        return str(value).replace(":", "")

    def dict_items(self):
        """Return ``(field_name, value)`` pairs from ``dataclasses.asdict(self)``."""
        return asdict(self).items()