mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[Inductor] Copy logic for ck-tile gemm instance configuration in Inductor max-autotune integration and test it (#2910)
* add op, gen_instances and test --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
138
python/ck4inductor/ck_tile_universal_gemm/gen_instances.py
Normal file
138
python/ck4inductor/ck_tile_universal_gemm/gen_instances.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import functools
|
||||
from .op import CKTileGemmOperation
|
||||
|
||||
|
||||
@functools.cache
def ops():
    """
    Generate the supported instance dataclasses.

    Builds one ``CKTileGemmOperation`` per combination of layout, datatype,
    tile shape, warp shape, warp-tile shape, M/N/K padding flag, scheduler
    and epilogue, for three pipelines.  The returned list is ordered as:

    1. all "CompV3" instances (128),
    2. all "CompV4" instances (64),
    3. all "Mem" instances (256).

    Because of ``functools.cache`` the list is built on the first call only
    and the very same list object is handed back on later calls.
    """
    import itertools

    # Option sets shared by every pipeline.
    layout_choices = [
        ("Row", "Row", "Row"),
        ("Row", "Col", "Row"),
    ]
    dtype_choices = [("FP16",) * 3, ("BF16",) * 3]
    warp_choices = [(2, 2, 1)]
    warp_tile_choices = [(32, 32, 16)]
    padding_choices = ["true", "false"]
    epilogue_choices = ["Default", "CShuffle"]

    def expand(pipeline, tile_choices, scheduler_choices):
        # itertools.product advances its right-most iterable fastest, so the
        # argument order below fixes the order of the generated instances:
        # layout is the slowest-changing axis and epilogue the fastest.
        combos = itertools.product(
            layout_choices,
            dtype_choices,
            tile_choices,
            warp_choices,
            warp_tile_choices,
            padding_choices,  # M
            padding_choices,  # N
            padding_choices,  # K
            scheduler_choices,
            epilogue_choices,
        )
        return [
            CKTileGemmOperation(
                layout_a=lay_a,
                layout_b=lay_b,
                layout_c=lay_c,
                datatype_a=dt_a,
                datatype_b=dt_b,
                datatype_c=dt_c,
                tile_m=t_m,
                tile_n=t_n,
                tile_k=t_k,
                warp_m=w_m,
                warp_n=w_n,
                warp_k=w_k,
                warp_tile_m=wt_m,
                warp_tile_n=wt_n,
                warp_tile_k=wt_k,
                m_is_padded=pad_m,
                n_is_padded=pad_n,
                k_is_padded=pad_k,
                pipeline=pipeline,
                scheduler=sched,
                epilogue=epi,
            )
            for (
                (lay_a, lay_b, lay_c),
                (dt_a, dt_b, dt_c),
                (t_m, t_n, t_k),
                (w_m, w_n, w_k),
                (wt_m, wt_n, wt_k),
                pad_m,
                pad_n,
                pad_k,
                sched,
                epi,
            ) in combos
        ]

    compute_v3_instances = expand(
        "CompV3",
        [(256, 256, 32), (256, 256, 64)],
        ["Intrawave"],
    )

    compute_v4_instances = expand(
        "CompV4",
        [(256, 256, 32)],  # half the tile size since it has double buffering
        ["Intrawave"],
    )

    mem_instances = expand(
        "Mem",
        [(256, 256, 32), (256, 256, 64)],
        ["Intrawave", "Interwave"],
    )

    return compute_v3_instances + compute_v4_instances + mem_instances
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: print the name of every generated instance,
    # one per line, in the order returned by ops().
    for instance in ops():
        print(instance.name())
|
||||
64
python/ck4inductor/ck_tile_universal_gemm/op.py
Normal file
64
python/ck4inductor/ck_tile_universal_gemm/op.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
class CKTileGemmOperation:
    """
    One CK-tile universal GEMM instance configuration.

    The class stores the configuration values as given and derives string
    identifiers from them; it performs no validation.
    """

    # Layouts of the A, B and C operands (only the first character of each
    # is used when building names).
    layout_a: str
    layout_b: str
    layout_c: str

    # Datatype names of the A, B and C operands.
    datatype_a: str
    datatype_b: str
    datatype_c: str

    # Tile extents along M, N and K.
    tile_m: int
    tile_n: int
    tile_k: int

    # Warp counts along M, N and K.
    warp_m: int
    warp_n: int
    warp_k: int

    # Warp-tile extents along M, N and K.
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    # Padding flags, kept as strings rather than bools.
    m_is_padded: str
    n_is_padded: str
    k_is_padded: str

    pipeline: str
    scheduler: str
    epilogue: str

    def layout_repr(self):
        """Return the first character of each of the A, B, C layouts, concatenated."""
        initials = (self.layout_a[0], self.layout_b[0], self.layout_c[0])
        return "".join(map(format, initials))

    def dtype_repr(self):
        """Return the A, B, C datatype names concatenated with no separator."""
        dtypes = (self.datatype_a, self.datatype_b, self.datatype_c)
        return "".join(map(format, dtypes))

    def tile_sizes(self):
        """
        Return the tile, warp and warp-tile shapes as three "_"-joined groups.

        Within each group the M, N and K values are concatenated without a
        separator (e.g. 256, 256, 32 becomes "25625632").
        NOTE(review): that concatenation is not uniquely decodable in general.
        """
        groups = (
            (self.tile_m, self.tile_n, self.tile_k),
            (self.warp_m, self.warp_n, self.warp_k),
            (self.warp_tile_m, self.warp_tile_n, self.warp_tile_k),
        )
        return "_".join("".join(map(format, group)) for group in groups)

    def name(self):
        """
        Return the instance name: a fixed prefix followed by the layout,
        datatype, tile-size, pipeline, scheduler and epilogue parts joined
        with "_".

        NOTE(review): the M/N/K padding flags are not part of the name, so
        instances that differ only in padding produce the same string.
        """
        prefix = "ck_tile_gemm_universal_"
        parts = (
            self.layout_repr(),
            self.dtype_repr(),
            self.tile_sizes(),
            self.pipeline,
            self.scheduler,
            self.epilogue,
        )
        return prefix + "_".join(map(format, parts))

    def dict_items(self):
        """Return the (field name, value) pairs of this instance, in field order."""
        as_mapping = asdict(self)
        return as_mapping.items()
|
||||
Reference in New Issue
Block a user