mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
add unit test for gen instances for gemms
add unit tests for conv and batched gemms
add unit test for preselected gemm instances
apply ruff lint
add license header for the unit test
add inductor pytest to CI
verbose pip install
switch the directory before installing python packages
move the inductor codegen test
try yet another workdir
Update Jenkinsfile
The directory looks right, fixing pip module not found by invoking pip directly
Update Jenkinsfile
invoke pytest directly since the module is not found
Update Dockerfile
Install setuptools
update package structure
bump setuptools
maybe fix data path for library sources
fix library search path for conv instances
fix path in pyproject definition
compare path used in gen_instances with one in pyproject.toml; fix the difference
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: c0b90f130f]
573 lines
20 KiB
Python
573 lines
20 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
from dataclasses import replace
|
|
from functools import lru_cache, partial
|
|
from typing import List
|
|
|
|
from ..util import library_path
|
|
|
|
from .op import CKGemmOperation
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _ck_library_dir():
    """
    Locate the folder holding the CK Universal Gemm instance sources.

    :return: ``<library_path()>/src/tensor_operation_instance/gpu/gemm_universal`` when it
        exists; otherwise an error is logged and ``None`` is returned
    """
    candidate = os.path.join(
        library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal"
    )
    if os.path.exists(candidate):
        return candidate
    log.error("CK library path %s does not exist", candidate)
    return None
|
|
|
|
|
|
def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
    """
    Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances

    Each line is expected to contain a ``DeviceGemm_Xdl_CShuffleV3<...>`` instantiation
    (for example a line of grep output). The template arguments are split on commas:

    - ``S<a, b, ...>`` becomes a tuple of ints;
    - integer literals become ``int``;
    - any other argument is kept as a (whitespace-stripped) string.

    Empty tuples for the ds layouts / ds dtypes are inserted at positions 2 and 6 before
    the arguments are passed positionally to `CKGemmOperation`.

    Some lines are skipped:

    - lines that do not contain the template name;
    - lines with an unterminated or non-integer ``S<...>`` argument (debug-logged);
    - lines whose argument list is rejected by `CKGemmOperation` with ``TypeError``
      (debug-logged).

    :param str_instances: lines of text to parse
    :return: the successfully parsed operations, in input order
    """
    template_name = "DeviceGemm_Xdl_CShuffleV3"

    def maybe_int(s):
        try:
            return int(s)
        except ValueError:
            return s

    def split_template_args(s_template_args):
        """Split a template argument string; return None if an ``S<...>`` argument is malformed."""
        template_args = []
        i_current = 0
        n = len(s_template_args)
        while i_current < n:
            if s_template_args[i_current] == " ":
                # skip whitespace
                i_current += 1
                continue
            if s_template_args.startswith("S<", i_current):
                # parse template S<Index...>
                i_close = s_template_args.find(">", i_current)
                if i_close == -1:
                    # unterminated sequence; continuing would never advance the cursor
                    return None
                try:
                    template_args.append(
                        tuple(
                            int(x)
                            for x in s_template_args[i_current + 2 : i_close].split(",")
                        )
                    )
                except ValueError:
                    return None
                # resume after the comma that follows the closing ">"
                i_next = s_template_args.find(",", i_close)
            else:
                # all string attributes must be either type aliases or global constants in C++
                i_next = s_template_args.find(",", i_current)
                end = n if i_next == -1 else i_next
                template_args.append(maybe_int(s_template_args[i_current:end].strip()))
            if i_next == -1:
                break
            i_current = i_next + 1
        return template_args

    op_instances = []
    for line in str_instances:
        # callers may pass case-insensitive grep output (gen_ops_library uses `grep -i`),
        # which also reports lines such as the header #include; those are not instances
        if template_name not in line:
            continue
        template_args = split_template_args(line.split(template_name)[-1].strip("<>, "))
        if template_args is None:
            log.debug("malformed S<...> template argument when parsing %s", line)
            continue

        template_args.insert(2, tuple())  # ds layout
        template_args.insert(6, tuple())  # ds dtype
        try:
            new_instance = CKGemmOperation(
                *template_args,  # type: ignore[arg-type]
            )
        except TypeError as e:
            log.debug("%s when parsing %s", e, line)
            continue
        op_instances.append(new_instance)
    return op_instances
|
|
|
|
|
|
def default_instances() -> List[CKGemmOperation]:
    """
    Return a single hard-coded F16 Row/Row/Row Universal Gemm instance to use as a fallback.

    The empty ``ds_layouts`` / ``ds_element_dtypes`` match what `parse_instances` and
    `gen_ops_preselected` supply for plain (no extra D tensor) GEMMs.
    """
    # fallback: known working op instance for problem size M=2240 K=256 N=2048
    # all string attributes must be either type aliases or global constants in C++

    return [
        CKGemmOperation(
            a_layout="Row",
            b_layout="Row",
            ds_layouts=tuple(),
            c_layout="Row",
            a_element_dtype="F16",
            b_element_dtype="F16",
            ds_element_dtypes=tuple(),
            c_element_dtype="F16",
            a_compute_dtype="F16",
            b_compute_dtype="F16",
            acc_dtype="F32",
            c_shuffle_dtype="F16",
            a_elementwise_op="PassThrough",
            b_elementwise_op="PassThrough",
            c_elementwise_op="PassThrough",
            gemm_specialization="GemmSpecialization::Default",
            block_size=256,
            m_per_block=224,
            n_per_block=256,
            k_per_block=64,
            a_k1=8,
            b_k1=2,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=7,
            n_xdl_per_wave=8,
            a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
            a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
            a_block_transfer_src_access_order=(1, 0, 2),
            a_block_transfer_src_vector_dim=2,
            a_block_transfer_src_scalar_per_vector=8,
            a_block_transfer_dst_scalar_per_vector_ak1=8,
            a_block_lds_extra_m=0,  # type: ignore[arg-type]
            b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
            b_block_transfer_thread_cluster_arrange_order=(0, 2, 1),
            b_block_transfer_src_access_order=(0, 2, 1),
            b_block_transfer_src_vector_dim=1,
            b_block_transfer_src_scalar_per_vector=8,
            b_block_transfer_dst_scalar_per_vector_bk1=2,
            b_block_lds_extra_n=0,  # type: ignore[arg-type]
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        )
    ]
|
|
|
|
|
|
@lru_cache(None)
def gen_ops_library() -> List[CKGemmOperation]:
    """
    Parse the Universal Gemm instances defined in the composable kernel library folder.

    The instance sources are searched with ``grep -inR`` for ``DeviceGemm_Xdl_CShuffleV3``.
    The matching lines are parsed with `parse_instances`. Instances whose scheduler is the
    literal ``BlkGemmPipeSched`` are expanded over every scheduler listed below. Likewise,
    instances whose GEMM specialization is the literal ``GemmSpec`` are expanded over
    every specialization listed below.

    The result is cached, so every call returns the same list object.

    :return: the expanded operations, or an empty list if the library folder or the
        ``grep`` executable is unavailable
    """
    ck_library_dir = _ck_library_dir()
    if not ck_library_dir:
        return []

    try:
        grep_result = subprocess.run(
            [
                "grep",
                "-inR",
                "DeviceGemm_Xdl_CShuffleV3",
                ck_library_dir,
            ],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        log.error(
            "grep executable not found; cannot collect CK instances from %s",
            ck_library_dir,
        )
        return []
    # grep exits with 1 when nothing matches and with 2 on errors (e.g. unreadable files);
    # whatever was printed to stdout before an error is still parsed below
    if grep_result.returncode > 1:
        log.warning(
            "grep over %s failed with exit code %d: %s",
            ck_library_dir,
            grep_result.returncode,
            grep_result.stderr.strip(),
        )

    op_instances = parse_instances(grep_result.stdout.strip().split("\n"))

    log.debug("ck instances from library: %d", len(op_instances))

    schedulers = [
        "BlockGemmPipelineScheduler::Intrawave",
        "BlockGemmPipelineScheduler::Interwave",
    ]
    gemm_specs = [
        "GemmSpecialization::Default",
        "GemmSpecialization::MPadding",
        "GemmSpecialization::NPadding",
        "GemmSpecialization::KPadding",
        "GemmSpecialization::MNPadding",
        "GemmSpecialization::MKPadding",
        "GemmSpecialization::NKPadding",
        "GemmSpecialization::MNKPadding",
    ]

    # substitute templated args by looping through their domains
    substitute_instances = []
    for instance in op_instances:
        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
        sub_spec = instance.gemm_specialization == "GemmSpec"
        schedulers_range = (
            schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
        )
        spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
        for scheduler in schedulers_range:
            for spec in spec_range:
                substitute_instances.append(
                    replace(
                        instance,
                        block_gemm_pipeline_scheduler=scheduler,
                        gemm_specialization=spec,
                    )
                )

    return substitute_instances
|
|
|
|
|
|
@lru_cache(None)
def gen_ops_preselected() -> List[CKGemmOperation]:
    """
    Return hand-picked (benchmarked) F16/F16/F16 Row/Col/Row Universal Gemm instances.

    The instances come in three families that share common parameters:

    - compute-friendly: 256-thread blocks;
    - memory-friendly: 128-thread blocks, Interwave scheduler, pipeline v2;
    - latency-friendly: 128-thread blocks, 16x16 XDL tiles, Intrawave scheduler,
      pipeline v1.

    The result is cached, so every call returns the same list object.
    """
    mnk_padding = "GemmSpecialization::MNKPadding"
    default_spec = "GemmSpecialization::Default"
    intrawave = "BlockGemmPipelineScheduler::Intrawave"

    # parameters shared by every preselected instance
    rcr_common = dict(
        a_layout="Row",
        b_layout="Col",
        c_layout="Row",
        ds_element_dtypes=tuple(),
        ds_layouts=tuple(),
        a_element_dtype="F16",
        b_element_dtype="F16",
        c_element_dtype="F16",
        acc_dtype="F32",
        c_shuffle_dtype="F16",
        a_elementwise_op="PassThrough",
        b_elementwise_op="PassThrough",
        c_elementwise_op="PassThrough",
        k_per_block=64,
        a_k1=8,
        b_k1=8,
        a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        a_block_transfer_src_access_order=(1, 0, 2),
        a_block_transfer_src_vector_dim=2,
        a_block_transfer_src_scalar_per_vector=8,
        a_block_transfer_dst_scalar_per_vector_ak1=8,
        a_block_lds_extra_m=0,
        b_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        b_block_transfer_src_access_order=(1, 0, 2),
        b_block_transfer_src_vector_dim=2,
        b_block_transfer_src_scalar_per_vector=8,
        b_block_transfer_dst_scalar_per_vector_bk1=8,
        b_block_lds_extra_n=0,
        a_compute_dtype="F16",
        b_compute_dtype="F16",
    )

    compute_friendly = dict(
        rcr_common,
        block_size=256,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
        c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
            1,
            32,
            1,
            8,
        ),
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
    )

    memory_friendly = dict(
        rcr_common,
        block_size=128,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave",
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v2",
    )

    latency_friendly = dict(
        rcr_common,
        gemm_specialization=default_spec,
        block_size=128,
        m_per_xdl=16,
        n_per_xdl=16,
        m_xdl_per_wave=1,
        n_xdl_per_wave=1,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        c_shuffle_m_xdl_per_wave_per_shuffle=1,
        c_shuffle_n_xdl_per_wave_per_shuffle=1,
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        block_gemm_pipeline_scheduler=intrawave,
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v1",
    )

    def build(family, **fields):
        # per-instance fields override the family defaults (same merge as functools.partial)
        return CKGemmOperation(**{**family, **fields})

    # --- compute-friendly instances ---
    selected = [
        build(
            compute_friendly,
            gemm_specialization=mnk_padding,
            m_per_block=224,
            n_per_block=256,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=7,
            n_xdl_per_wave=8,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            block_gemm_pipeline_scheduler=intrawave,
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        )
    ]
    for spec in (mnk_padding, default_spec):
        for version in (
            "BlockGemmPipelineVersion::v3",
            "BlockGemmPipelineVersion::v4",
            "BlockGemmPipelineVersion::v5",
        ):
            selected.append(
                build(
                    compute_friendly,
                    gemm_specialization=spec,
                    m_per_block=128,
                    n_per_block=128,
                    m_per_xdl=32,
                    n_per_xdl=32,
                    m_xdl_per_wave=2,
                    n_xdl_per_wave=2,
                    c_shuffle_m_xdl_per_wave_per_shuffle=1,
                    c_shuffle_n_xdl_per_wave_per_shuffle=1,
                    block_gemm_pipeline_scheduler=intrawave,
                    block_gemm_pipeline_version=version,
                )
            )

    # --- memory-friendly instances ---
    # columns: gemm_specialization, m_per_block, n_per_block, per-XDL tile size (M and N),
    #          m_xdl_per_wave, n_xdl_per_wave, c_shuffle m/n xdl per wave per shuffle,
    #          c_shuffle cluster lengths, c_shuffle scalar per vector
    memory_tiles = [
        (default_spec, 16, 32, 16, 1, 1, 1, 1, (1, 16, 1, 8), 4),
        (mnk_padding, 16, 32, 16, 1, 1, 1, 1, (1, 16, 1, 8), 4),
        (mnk_padding, 16, 64, 16, 1, 2, 1, 2, (1, 16, 1, 8), 8),
        (mnk_padding, 32, 64, 32, 1, 1, 1, 1, (1, 16, 1, 8), 8),
        (mnk_padding, 32, 128, 32, 1, 2, 1, 1, (1, 16, 1, 8), 8),
        (default_spec, 32, 16, 16, 1, 1, 1, 1, (1, 32, 1, 4), 4),
        (mnk_padding, 32, 16, 16, 1, 1, 1, 1, (1, 32, 1, 4), 4),
        (mnk_padding, 64, 16, 16, 2, 1, 2, 1, (1, 64, 1, 2), 8),
        (mnk_padding, 64, 32, 32, 1, 1, 1, 1, (1, 32, 1, 4), 8),
        (mnk_padding, 128, 32, 32, 2, 1, 2, 1, (1, 32, 1, 4), 8),
    ]
    for (
        spec,
        m_block,
        n_block,
        xdl,
        m_waves,
        n_waves,
        shuffle_m,
        shuffle_n,
        shuffle_cluster,
        shuffle_vector,
    ) in memory_tiles:
        selected.append(
            build(
                memory_friendly,
                gemm_specialization=spec,
                m_per_block=m_block,
                n_per_block=n_block,
                m_per_xdl=xdl,
                n_per_xdl=xdl,
                m_xdl_per_wave=m_waves,
                n_xdl_per_wave=n_waves,
                c_shuffle_m_xdl_per_wave_per_shuffle=shuffle_m,
                c_shuffle_n_xdl_per_wave_per_shuffle=shuffle_n,
                c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=shuffle_cluster,
                c_shuffle_block_transfer_scalar_per_vector_n_per_block=shuffle_vector,
            )
        )

    # --- latency-friendly instances ---
    for m_block, n_block, shuffle_cluster in (
        (16, 32, (1, 16, 1, 8)),
        (32, 16, (1, 32, 1, 4)),
    ):
        selected.append(
            build(
                latency_friendly,
                m_per_block=m_block,
                n_per_block=n_block,
                c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=shuffle_cluster,
            )
        )

    return selected
|
|
|
|
|
|
if __name__ == "__main__":
    # manual smoke test: dump every instance parsed from the CK library sources
    library_ops = gen_ops_library()
    print(library_ops)
|