mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Reorganize project folders (#6)
This commit is contained in:
0
python/ck4inductor/__init__.py
Normal file
0
python/ck4inductor/__init__.py
Normal file
149
python/ck4inductor/batched_universal_gemm/gen_instances.py
Normal file
149
python/ck4inductor/batched_universal_gemm/gen_instances.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import replace
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
from ..util import library_path
|
||||
|
||||
from .op import CKBatchedGemmOperation
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ck_library_dir():
|
||||
gemm_instances_path = os.path.join(
|
||||
library_path(),
|
||||
"src",
|
||||
"tensor_operation_instance",
|
||||
"gpu",
|
||||
"gemm_universal_batched",
|
||||
)
|
||||
if not os.path.exists(gemm_instances_path):
|
||||
log.error("CK library path %s does not exist", gemm_instances_path)
|
||||
return None
|
||||
return gemm_instances_path
|
||||
|
||||
|
||||
def parse_instances(str_instances: List[str]) -> List[CKBatchedGemmOperation]:
|
||||
"""
|
||||
Parse the lines containing Universal Gemm template instances into `CKBatchedGemmOperation` instances
|
||||
"""
|
||||
|
||||
def maybe_int(s):
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return s
|
||||
|
||||
op_instances = []
|
||||
for line in str_instances:
|
||||
s_template_args = line.split("DeviceBatchedGemmMultiD_Xdl_CShuffle_V3")[
|
||||
-1
|
||||
].strip("<>, ")
|
||||
template_args = []
|
||||
i_current = 0
|
||||
while i_current < len(s_template_args):
|
||||
if s_template_args[i_current] == " ":
|
||||
# skip whitespace
|
||||
i_current += 1
|
||||
continue
|
||||
elif s_template_args[i_current : i_current + 2] == "S<":
|
||||
# parse template S<Index...>
|
||||
i_next = s_template_args.find(">", i_current)
|
||||
template_args.append(
|
||||
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
|
||||
)
|
||||
i_current = i_next + 2
|
||||
else:
|
||||
# all string attributes must be either type aliases or global constants in C++
|
||||
i_next = s_template_args.find(",", i_current)
|
||||
template_args.append(
|
||||
maybe_int(
|
||||
s_template_args[i_current : i_next if i_next != -1 else None]
|
||||
)
|
||||
)
|
||||
if i_next != -1:
|
||||
i_current = i_next + 1
|
||||
if i_next == -1:
|
||||
break
|
||||
|
||||
# ds layout and dtype are parsed as placeholder; reset value
|
||||
template_args[2] = tuple() # ds layout
|
||||
template_args[6] = tuple() # ds dtype
|
||||
|
||||
new_instance = CKBatchedGemmOperation(
|
||||
*template_args, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
op_instances.append(new_instance)
|
||||
return op_instances
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def gen_ops_library() -> List[CKBatchedGemmOperation]:
|
||||
"""
|
||||
Parse the Universal Gemm instances defined in the composable kernel library folder.
|
||||
"""
|
||||
ck_library_dir = _ck_library_dir()
|
||||
if not ck_library_dir:
|
||||
return []
|
||||
|
||||
grep_result = subprocess.run(
|
||||
[
|
||||
"grep",
|
||||
"-inR",
|
||||
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3",
|
||||
_ck_library_dir(),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
|
||||
|
||||
log.debug("ck instances from library: %d", len(op_instances))
|
||||
|
||||
schedulers = [
|
||||
"BlockGemmPipelineScheduler::Intrawave",
|
||||
"BlockGemmPipelineScheduler::Interwave",
|
||||
]
|
||||
gemm_specs = [
|
||||
"GemmSpecialization::Default",
|
||||
"GemmSpecialization::MPadding",
|
||||
"GemmSpecialization::NPadding",
|
||||
"GemmSpecialization::KPadding",
|
||||
"GemmSpecialization::MNPadding",
|
||||
"GemmSpecialization::MKPadding",
|
||||
"GemmSpecialization::NKPadding",
|
||||
"GemmSpecialization::MNKPadding",
|
||||
]
|
||||
|
||||
# substitute templated args by looping through their domains
|
||||
substitute_instances = []
|
||||
for instance in op_instances:
|
||||
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
|
||||
sub_spec = instance.gemm_specialization == "GemmSpec"
|
||||
schedulers_range = (
|
||||
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
|
||||
)
|
||||
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
|
||||
for scheduler in schedulers_range:
|
||||
for spec in spec_range:
|
||||
substitute_instances.append(
|
||||
replace(
|
||||
instance,
|
||||
block_gemm_pipeline_scheduler=scheduler,
|
||||
gemm_specialization=spec,
|
||||
)
|
||||
)
|
||||
|
||||
return substitute_instances
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(gen_ops_library())
|
||||
99
python/ck4inductor/batched_universal_gemm/op.py
Normal file
99
python/ck4inductor/batched_universal_gemm/op.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CKBatchedGemmOperation:
|
||||
"""
|
||||
A python dataclass storing the template parameters of a CK Universal Gemm template instance
|
||||
"""
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
ds_layouts: Tuple[str] # addmm specific
|
||||
c_layout: str
|
||||
|
||||
a_element_dtype: str
|
||||
b_element_dtype: str
|
||||
ds_element_dtypes: Tuple[str] # addmm specific
|
||||
c_element_dtype: str
|
||||
|
||||
acc_dtype: str
|
||||
c_shuffle_dtype: str
|
||||
|
||||
a_elementwise_op: str
|
||||
b_elementwise_op: str
|
||||
c_elementwise_op: str
|
||||
|
||||
gemm_specialization: str
|
||||
|
||||
block_size: int
|
||||
|
||||
m_per_block: int
|
||||
n_per_block: int
|
||||
k_per_block: int
|
||||
|
||||
a_k1: int
|
||||
b_k1: int
|
||||
|
||||
m_per_xdl: int
|
||||
n_per_xdl: int
|
||||
|
||||
m_xdl_per_wave: int
|
||||
n_xdl_per_wave: int
|
||||
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
|
||||
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_vector_dim: int
|
||||
a_block_transfer_src_scalar_per_vector: int
|
||||
a_block_transfer_dst_scalar_per_vector_ak1: int
|
||||
a_block_lds_extra_m: bool
|
||||
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
|
||||
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
b_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
|
||||
b_block_transfer_src_vector_dim: int
|
||||
b_block_transfer_src_scalar_per_vector: int
|
||||
b_block_transfer_dst_scalar_per_vector_bk1: int
|
||||
b_block_lds_extra_n: bool
|
||||
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle: int
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle: int
|
||||
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
|
||||
Tuple[int, int, int, int]
|
||||
)
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block: Tuple[int]
|
||||
block_gemm_pipeline_scheduler: str
|
||||
block_gemm_pipeline_version: str
|
||||
|
||||
a_compute_dtype: Optional[str] = None
|
||||
b_compute_dtype: Optional[str] = None
|
||||
|
||||
def name(self):
|
||||
# cpp alias for template instance
|
||||
return f"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_{self.key_name()}"
|
||||
|
||||
def key_name(self):
|
||||
# TBD; must be unique per instance. Intended to use as dict key
|
||||
return "_".join(
|
||||
[
|
||||
"K"
|
||||
+ field_name.replace("_", "").lower()
|
||||
+ "V"
|
||||
+ (
|
||||
"x".join(map(str, iter(field_value)))
|
||||
if isinstance(field_value, tuple)
|
||||
else str(field_value).replace(":", "")
|
||||
)
|
||||
for field_name, field_value in self.dict_items()
|
||||
]
|
||||
)
|
||||
|
||||
def dict_items(self):
|
||||
return asdict(self).items()
|
||||
165
python/ck4inductor/grouped_conv_fwd/gen_instances.py
Normal file
165
python/ck4inductor/grouped_conv_fwd/gen_instances.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import replace
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
from ..util import library_path
|
||||
|
||||
from .op import CKGroupedConvFwdOp
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ck_conv_instances_path():
|
||||
conv_instances_path = os.path.join( # noqa: F821
|
||||
library_path(),
|
||||
"include",
|
||||
"ck",
|
||||
"library",
|
||||
"tensor_operation_instance",
|
||||
"gpu",
|
||||
"grouped_conv_fwd",
|
||||
)
|
||||
if not os.path.exists(conv_instances_path):
|
||||
log.error(
|
||||
"CK library conv instances path %s does not exist", conv_instances_path
|
||||
)
|
||||
return None
|
||||
return conv_instances_path
|
||||
|
||||
|
||||
def parse_instances(str_instances: List[str]) -> List[CKGroupedConvFwdOp]:
|
||||
"""
|
||||
Parse the lines containing Grouped Convolution Forward template instances
|
||||
into `CKGroupedConvFwdOp` instances
|
||||
"""
|
||||
|
||||
def maybe_int(s):
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return s
|
||||
|
||||
op_instances = []
|
||||
# TODO: maybe use libclang for parsing C++ code in the future
|
||||
# to avoid this hacky parsing logic below ? :) - copilot
|
||||
for line in str_instances:
|
||||
s_template_args = line.split("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")[
|
||||
-1
|
||||
].strip("<>, ")
|
||||
template_args = []
|
||||
i_current = 0
|
||||
while i_current < len(s_template_args):
|
||||
if s_template_args[i_current] == " ":
|
||||
# skip whitespace
|
||||
i_current += 1
|
||||
continue
|
||||
elif s_template_args[i_current : i_current + 2] == "S<":
|
||||
# parse template S<Index...>
|
||||
i_next = s_template_args.find(">", i_current)
|
||||
template_args.append(
|
||||
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
|
||||
)
|
||||
i_current = i_next + 2
|
||||
else:
|
||||
# all string attributes must be either type aliases or global constants in C++
|
||||
i_next = s_template_args.find(",", i_current)
|
||||
template_args.append(
|
||||
maybe_int(
|
||||
s_template_args[i_current : i_next if i_next != -1 else None]
|
||||
)
|
||||
)
|
||||
if i_next != -1:
|
||||
i_current = i_next + 1
|
||||
if i_next == -1:
|
||||
break
|
||||
|
||||
template_args[0] = -1 # n_dim_spatial
|
||||
template_args[3] = tuple() # ds_layout
|
||||
template_args[9] = tuple() # ds_element_dtype
|
||||
|
||||
new_instance = CKGroupedConvFwdOp(
|
||||
*template_args, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
op_instances.append(new_instance)
|
||||
return op_instances
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
|
||||
"""
|
||||
Parse the Grouped Convolution Forward instances
|
||||
defined in the Composable Kernel library folder.
|
||||
"""
|
||||
ck_library_dir = _ck_conv_instances_path()
|
||||
if not ck_library_dir:
|
||||
return []
|
||||
|
||||
grep_result = subprocess.run(
|
||||
[
|
||||
"grep",
|
||||
"-inR",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
ck_library_dir,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
|
||||
|
||||
log.debug("ck instances from library: %d", len(op_instances))
|
||||
|
||||
schedulers = [
|
||||
"BlockGemmPipelineScheduler::Intrawave",
|
||||
"BlockGemmPipelineScheduler::Interwave",
|
||||
]
|
||||
conv_specs = [
|
||||
"ConvolutionForwardSpecialization::Default",
|
||||
"ConvolutionForwardSpecialization::Filter1x1Pad0",
|
||||
"ConvolutionForwardSpecialization::Filter1x1Stride1Pad0",
|
||||
"ConvolutionForwardSpecialization::OddC",
|
||||
]
|
||||
|
||||
# substitute templated args by looping through their domains
|
||||
substitute_instances = []
|
||||
for instance in op_instances:
|
||||
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
|
||||
sub_spec = instance.conv_forward_specialization == "ConvSpec"
|
||||
schedulers_range = (
|
||||
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
|
||||
)
|
||||
spec_range = conv_specs if sub_spec else [instance.conv_forward_specialization]
|
||||
for scheduler in schedulers_range:
|
||||
for spec in spec_range:
|
||||
for channels_last in [True, False]:
|
||||
if channels_last:
|
||||
a_layout = "NHWGC"
|
||||
e_layout = "NHWGK"
|
||||
else:
|
||||
a_layout = "NGCHW"
|
||||
e_layout = "NGKHW"
|
||||
substitute_instances.append(
|
||||
replace(
|
||||
instance,
|
||||
block_gemm_pipeline_scheduler=scheduler,
|
||||
conv_forward_specialization=spec,
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
n_dim_spatial=2,
|
||||
a_layout=a_layout,
|
||||
b_layout="GKYXC",
|
||||
e_layout=e_layout,
|
||||
)
|
||||
)
|
||||
|
||||
return substitute_instances
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(gen_conv_ops_library())
|
||||
93
python/ck4inductor/grouped_conv_fwd/op.py
Normal file
93
python/ck4inductor/grouped_conv_fwd/op.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CKGroupedConvFwdOp:
|
||||
n_dim_spatial: int
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
ds_layout: Tuple[str]
|
||||
e_layout: str
|
||||
a_element_dtype: str
|
||||
b_element_dtype: str
|
||||
acc_dtype: str
|
||||
c_shuffle_dtype: str
|
||||
ds_element_dtype: Tuple[str]
|
||||
e_element_dtype: str
|
||||
a_elementwise_op: str
|
||||
b_elementwise_op: str
|
||||
cde_elementwise_op: str
|
||||
conv_forward_specialization: str
|
||||
gemm_specialization: str
|
||||
|
||||
block_size: int
|
||||
m_per_block: int
|
||||
n_per_block: int
|
||||
k_per_block: int
|
||||
ak1: int
|
||||
bk1: int
|
||||
m_per_xdl: int
|
||||
n_per_xdl: int
|
||||
m_xdl_per_wave: int
|
||||
n_xdl_per_wave: int
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
|
||||
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_vector_dim: int
|
||||
a_block_transfer_src_scalar_per_vector: int
|
||||
a_block_transfer_dst_scalar_per_vector_ak1: int
|
||||
a_block_lds_extra_m: bool
|
||||
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
|
||||
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
b_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
|
||||
b_block_transfer_src_vector_dim: int
|
||||
b_block_transfer_src_scalar_per_vector: int
|
||||
b_block_transfer_dst_scalar_per_vector_bk1: int
|
||||
b_block_lds_extra_n: bool
|
||||
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle: int
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle: int
|
||||
cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[ # noqa
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
]
|
||||
cde_block_transfer_scalar_per_vector_n_per_block: int
|
||||
block_gemm_pipeline_scheduler: str
|
||||
block_gemm_pipeline_version: str
|
||||
|
||||
a_compute_dtype: Optional[str] = None
|
||||
b_compute_dtype: Optional[str] = None
|
||||
|
||||
def name(self):
|
||||
# cpp alias for template instance
|
||||
return (
|
||||
f"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_"
|
||||
f"{self.key_name()}"
|
||||
)
|
||||
|
||||
def key_name(self):
|
||||
# TBD; must be unique per instance. Intended to use as dict key
|
||||
return "_".join(
|
||||
[
|
||||
"K"
|
||||
+ field_name.replace("_", "").lower()
|
||||
+ "V"
|
||||
+ (
|
||||
"x".join(map(str, iter(field_value)))
|
||||
if isinstance(field_value, tuple)
|
||||
else str(field_value).replace(":", "")
|
||||
)
|
||||
for field_name, field_value in self.dict_items()
|
||||
]
|
||||
)
|
||||
|
||||
def dict_items(self):
|
||||
return asdict(self).items()
|
||||
572
python/ck4inductor/universal_gemm/gen_instances.py
Normal file
572
python/ck4inductor/universal_gemm/gen_instances.py
Normal file
@@ -0,0 +1,572 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import replace
|
||||
from functools import lru_cache, partial
|
||||
from typing import List
|
||||
|
||||
from ..util import library_path
|
||||
|
||||
from .op import CKGemmOperation
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ck_library_dir():
|
||||
gemm_instances_path = os.path.join(
|
||||
library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal"
|
||||
)
|
||||
if not os.path.exists(gemm_instances_path):
|
||||
log.error("CK library path %s does not exist", gemm_instances_path)
|
||||
return None
|
||||
return gemm_instances_path
|
||||
|
||||
|
||||
def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
|
||||
"""
|
||||
Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances
|
||||
"""
|
||||
|
||||
def maybe_int(s):
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return s
|
||||
|
||||
op_instances = []
|
||||
for line in str_instances:
|
||||
s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ")
|
||||
template_args = []
|
||||
i_current = 0
|
||||
while i_current < len(s_template_args):
|
||||
if s_template_args[i_current] == " ":
|
||||
# skip whitespace
|
||||
i_current += 1
|
||||
continue
|
||||
elif s_template_args[i_current : i_current + 2] == "S<":
|
||||
# parse template S<Index...>
|
||||
i_next = s_template_args.find(">", i_current)
|
||||
template_args.append(
|
||||
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
|
||||
)
|
||||
i_current = i_next + 2
|
||||
else:
|
||||
# all string attributes must be either type aliases or global constants in C++
|
||||
i_next = s_template_args.find(",", i_current)
|
||||
template_args.append(
|
||||
maybe_int(
|
||||
s_template_args[i_current : i_next if i_next != -1 else None]
|
||||
)
|
||||
)
|
||||
if i_next != -1:
|
||||
i_current = i_next + 1
|
||||
if i_next == -1:
|
||||
break
|
||||
|
||||
template_args.insert(2, tuple()) # ds layout
|
||||
template_args.insert(6, tuple()) # ds dtype
|
||||
try:
|
||||
new_instance = CKGemmOperation(
|
||||
*template_args, # type: ignore[arg-type]
|
||||
)
|
||||
op_instances.append(new_instance)
|
||||
except TypeError as e:
|
||||
log.debug(f"{e} when parsing {line}")
|
||||
return op_instances
|
||||
|
||||
|
||||
def default_instances() -> List[CKGemmOperation]:
|
||||
# fallback: known working op instance for problem size M=2240 K=256 N=2048
|
||||
# all string attributes must be either type aliases or global constants in C++
|
||||
|
||||
return [
|
||||
CKGemmOperation(
|
||||
a_layout="Row",
|
||||
b_layout="Row",
|
||||
c_layout="Row",
|
||||
a_element_dtype="F16",
|
||||
b_element_dtype="F16",
|
||||
c_element_dtype="F16",
|
||||
a_compute_dtype="F16",
|
||||
b_compute_dtype="F16",
|
||||
acc_dtype="F32",
|
||||
c_shuffle_dtype="F16",
|
||||
a_elementwise_op="PassThrough",
|
||||
b_elementwise_op="PassThrough",
|
||||
c_elementwise_op="PassThrough",
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
block_size=256,
|
||||
m_per_block=224,
|
||||
n_per_block=256,
|
||||
k_per_block=64,
|
||||
a_k1=8,
|
||||
b_k1=2,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=7,
|
||||
n_xdl_per_wave=8,
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
|
||||
a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
|
||||
a_block_transfer_src_access_order=(1, 0, 2),
|
||||
a_block_transfer_src_vector_dim=2,
|
||||
a_block_transfer_src_scalar_per_vector=8,
|
||||
a_block_transfer_dst_scalar_per_vector_ak1=8,
|
||||
a_block_lds_extra_m=0, # type: ignore[arg-type]
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
|
||||
b_block_transfer_thread_cluster_arrange_order=(0, 2, 1),
|
||||
b_block_transfer_src_access_order=(0, 2, 1),
|
||||
b_block_transfer_src_vector_dim=1,
|
||||
b_block_transfer_src_scalar_per_vector=8,
|
||||
b_block_transfer_dst_scalar_per_vector_bk1=2,
|
||||
b_block_lds_extra_n=0, # type: ignore[arg-type]
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=2,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def gen_ops_library() -> List[CKGemmOperation]:
|
||||
"""
|
||||
Parse the Universal Gemm instances defined in the composable kernel library folder.
|
||||
"""
|
||||
ck_library_dir = _ck_library_dir()
|
||||
if not ck_library_dir:
|
||||
return []
|
||||
|
||||
grep_result = subprocess.run(
|
||||
[
|
||||
"grep",
|
||||
"-inR",
|
||||
"DeviceGemm_Xdl_CShuffleV3",
|
||||
_ck_library_dir(),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
|
||||
|
||||
log.debug("ck instances from library: %d", len(op_instances))
|
||||
|
||||
schedulers = [
|
||||
"BlockGemmPipelineScheduler::Intrawave",
|
||||
"BlockGemmPipelineScheduler::Interwave",
|
||||
]
|
||||
gemm_specs = [
|
||||
"GemmSpecialization::Default",
|
||||
"GemmSpecialization::MPadding",
|
||||
"GemmSpecialization::NPadding",
|
||||
"GemmSpecialization::KPadding",
|
||||
"GemmSpecialization::MNPadding",
|
||||
"GemmSpecialization::MKPadding",
|
||||
"GemmSpecialization::NKPadding",
|
||||
"GemmSpecialization::MNKPadding",
|
||||
]
|
||||
|
||||
# substitute templated args by looping through their domains
|
||||
substitute_instances = []
|
||||
for instance in op_instances:
|
||||
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
|
||||
sub_spec = instance.gemm_specialization == "GemmSpec"
|
||||
schedulers_range = (
|
||||
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
|
||||
)
|
||||
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
|
||||
for scheduler in schedulers_range:
|
||||
for spec in spec_range:
|
||||
substitute_instances.append(
|
||||
replace(
|
||||
instance,
|
||||
block_gemm_pipeline_scheduler=scheduler,
|
||||
gemm_specialization=spec,
|
||||
)
|
||||
)
|
||||
|
||||
return substitute_instances
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def gen_ops_preselected() -> List[CKGemmOperation]:
|
||||
"""
|
||||
Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances
|
||||
"""
|
||||
ck_gemm_f16_rcr = partial(
|
||||
CKGemmOperation,
|
||||
a_layout="Row",
|
||||
b_layout="Col",
|
||||
c_layout="Row",
|
||||
ds_element_dtypes=tuple(),
|
||||
ds_layouts=tuple(),
|
||||
a_element_dtype="F16",
|
||||
b_element_dtype="F16",
|
||||
c_element_dtype="F16",
|
||||
acc_dtype="F32",
|
||||
c_shuffle_dtype="F16",
|
||||
a_elementwise_op="PassThrough",
|
||||
b_elementwise_op="PassThrough",
|
||||
c_elementwise_op="PassThrough",
|
||||
k_per_block=64,
|
||||
a_k1=8,
|
||||
b_k1=8,
|
||||
a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
|
||||
a_block_transfer_src_access_order=(1, 0, 2),
|
||||
a_block_transfer_src_vector_dim=2,
|
||||
a_block_transfer_src_scalar_per_vector=8,
|
||||
a_block_transfer_dst_scalar_per_vector_ak1=8,
|
||||
a_block_lds_extra_m=0,
|
||||
b_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
|
||||
b_block_transfer_src_access_order=(1, 0, 2),
|
||||
b_block_transfer_src_vector_dim=2,
|
||||
b_block_transfer_src_scalar_per_vector=8,
|
||||
b_block_transfer_dst_scalar_per_vector_bk1=8,
|
||||
b_block_lds_extra_n=0,
|
||||
a_compute_dtype="F16",
|
||||
b_compute_dtype="F16",
|
||||
)
|
||||
ck_gemm_f16_rcr_compute_friendly = partial(
|
||||
ck_gemm_f16_rcr,
|
||||
block_size=256,
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
)
|
||||
ck_gemm_f16_rcr_memory_friendly = partial(
|
||||
ck_gemm_f16_rcr,
|
||||
block_size=128,
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v2",
|
||||
)
|
||||
ck_gemm_f16_rcr_latency_friendly = partial(
|
||||
ck_gemm_f16_rcr,
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
block_size=128,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v1",
|
||||
)
|
||||
return [
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=224,
|
||||
n_per_block=256,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=7,
|
||||
n_xdl_per_wave=8,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=2,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
|
||||
),
|
||||
ck_gemm_f16_rcr_compute_friendly(
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
m_per_block=128,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
|
||||
block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
m_per_block=16,
|
||||
n_per_block=32,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=16,
|
||||
n_per_block=32,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=16,
|
||||
n_per_block=64,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=2,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=32,
|
||||
n_per_block=64,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=32,
|
||||
n_per_block=128,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=2,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::Default",
|
||||
m_per_block=32,
|
||||
n_per_block=16,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=32,
|
||||
n_per_block=16,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=64,
|
||||
n_per_block=16,
|
||||
m_per_xdl=16,
|
||||
n_per_xdl=16,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=2,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
64,
|
||||
1,
|
||||
2,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=64,
|
||||
n_per_block=32,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=1,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_memory_friendly(
|
||||
gemm_specialization="GemmSpecialization::MNKPadding",
|
||||
m_per_block=128,
|
||||
n_per_block=32,
|
||||
m_per_xdl=32,
|
||||
n_per_xdl=32,
|
||||
m_xdl_per_wave=2,
|
||||
n_xdl_per_wave=1,
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle=2,
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle=1,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
),
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
|
||||
),
|
||||
ck_gemm_f16_rcr_latency_friendly(
|
||||
m_per_block=16,
|
||||
n_per_block=32,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
16,
|
||||
1,
|
||||
8,
|
||||
),
|
||||
),
|
||||
ck_gemm_f16_rcr_latency_friendly(
|
||||
m_per_block=32,
|
||||
n_per_block=16,
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
|
||||
1,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(gen_ops_library())
|
||||
99
python/ck4inductor/universal_gemm/op.py
Normal file
99
python/ck4inductor/universal_gemm/op.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CKGemmOperation:
|
||||
"""
|
||||
A python dataclass storing the template parameters of a CK Universal Gemm template instance
|
||||
"""
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
ds_layouts: Tuple[str] # addmm specific
|
||||
c_layout: str
|
||||
|
||||
a_element_dtype: str
|
||||
b_element_dtype: str
|
||||
ds_element_dtypes: Tuple[str] # addmm specific
|
||||
c_element_dtype: str
|
||||
|
||||
acc_dtype: str
|
||||
c_shuffle_dtype: str
|
||||
|
||||
a_elementwise_op: str
|
||||
b_elementwise_op: str
|
||||
c_elementwise_op: str
|
||||
|
||||
gemm_specialization: str
|
||||
|
||||
block_size: int
|
||||
|
||||
m_per_block: int
|
||||
n_per_block: int
|
||||
k_per_block: int
|
||||
|
||||
a_k1: int
|
||||
b_k1: int
|
||||
|
||||
m_per_xdl: int
|
||||
n_per_xdl: int
|
||||
|
||||
m_xdl_per_wave: int
|
||||
n_xdl_per_wave: int
|
||||
|
||||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
|
||||
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
a_block_transfer_src_vector_dim: int
|
||||
a_block_transfer_src_scalar_per_vector: int
|
||||
a_block_transfer_dst_scalar_per_vector_ak1: int
|
||||
a_block_lds_extra_m: bool
|
||||
|
||||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
|
||||
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
|
||||
b_block_transfer_src_access_order: Tuple[int, int, int]
|
||||
|
||||
b_block_transfer_src_vector_dim: int
|
||||
b_block_transfer_src_scalar_per_vector: int
|
||||
b_block_transfer_dst_scalar_per_vector_bk1: int
|
||||
b_block_lds_extra_n: bool
|
||||
|
||||
c_shuffle_m_xdl_per_wave_per_shuffle: int
|
||||
c_shuffle_n_xdl_per_wave_per_shuffle: int
|
||||
|
||||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
|
||||
Tuple[int, int, int, int]
|
||||
)
|
||||
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
|
||||
block_gemm_pipeline_scheduler: str
|
||||
block_gemm_pipeline_version: str
|
||||
|
||||
a_compute_dtype: Optional[str] = None
|
||||
b_compute_dtype: Optional[str] = None
|
||||
|
||||
def name(self):
|
||||
# cpp alias for template instance
|
||||
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
|
||||
|
||||
def key_name(self):
|
||||
# TBD; must be unique per instance. Intended to use as dict key
|
||||
return "_".join(
|
||||
[
|
||||
"K"
|
||||
+ field_name.replace("_", "").lower()
|
||||
+ "V"
|
||||
+ (
|
||||
"x".join(map(str, iter(field_value)))
|
||||
if isinstance(field_value, tuple)
|
||||
else str(field_value).replace(":", "")
|
||||
)
|
||||
for field_name, field_value in self.dict_items()
|
||||
]
|
||||
)
|
||||
|
||||
def dict_items(self):
|
||||
return asdict(self).items()
|
||||
10
python/ck4inductor/util.py
Normal file
10
python/ck4inductor/util.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def library_path():
|
||||
return os.path.join(os.path.dirname(__file__), "library")
|
||||
Reference in New Issue
Block a user