[Inductor] Copy logic for ck-tile gemm instance configuration in Inductor max-autotune integration and test it (#2910)

* add op, gen_instances and test

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Max Podkorytov
2025-11-19 09:38:02 -08:00
committed by GitHub
parent 7fe7aa76f5
commit e6e2e04edb
4 changed files with 221 additions and 1 deletions

View File

@@ -0,0 +1,138 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import functools
from .op import CKTileGemmOperation
@functools.cache
def ops():
"""
Generate the supported instance dataclasses
"""
import itertools
compute_v3_instances = [
CKTileGemmOperation(
layout_a=layout_a,
layout_b=layout_b,
layout_c=layout_c,
datatype_a=datatype_a,
datatype_b=datatype_b,
datatype_c=datatype_c,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
warp_m=warp_m,
warp_n=warp_n,
warp_k=warp_k,
warp_tile_m=warp_tile_m,
warp_tile_n=warp_tile_n,
warp_tile_k=warp_tile_k,
m_is_padded=m_is_padded,
n_is_padded=n_is_padded,
k_is_padded=k_is_padded,
pipeline="CompV3",
scheduler="Intrawave",
epilogue=epilogue,
)
for (layout_a, layout_b, layout_c) in [
("Row", "Row", "Row"),
("Row", "Col", "Row"),
]
for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3]
for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)]
for (warp_m, warp_n, warp_k) in [(2, 2, 1)]
for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)]
for m_is_padded in ["true", "false"]
for n_is_padded in ["true", "false"]
for k_is_padded in ["true", "false"]
for epilogue in ["Default", "CShuffle"]
]
compute_v4_instances = [
CKTileGemmOperation(
layout_a=layout_a,
layout_b=layout_b,
layout_c=layout_c,
datatype_a=datatype_a,
datatype_b=datatype_b,
datatype_c=datatype_c,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
warp_m=warp_m,
warp_n=warp_n,
warp_k=warp_k,
warp_tile_m=warp_tile_m,
warp_tile_n=warp_tile_n,
warp_tile_k=warp_tile_k,
m_is_padded=m_is_padded,
n_is_padded=n_is_padded,
k_is_padded=k_is_padded,
pipeline="CompV4",
scheduler="Intrawave",
epilogue=epilogue,
)
for (layout_a, layout_b, layout_c) in [
("Row", "Row", "Row"),
("Row", "Col", "Row"),
]
for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3]
for (tile_m, tile_n, tile_k) in [
(256, 256, 32)
] # half the tile size since it has double buffering
for (warp_m, warp_n, warp_k) in [(2, 2, 1)]
for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)]
for m_is_padded in ["true", "false"]
for n_is_padded in ["true", "false"]
for k_is_padded in ["true", "false"]
for epilogue in ["Default", "CShuffle"]
]
mem_instances = [
CKTileGemmOperation(
layout_a=layout_a,
layout_b=layout_b,
layout_c=layout_c,
datatype_a=datatype_a,
datatype_b=datatype_b,
datatype_c=datatype_c,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
warp_m=warp_m,
warp_n=warp_n,
warp_k=warp_k,
warp_tile_m=warp_tile_m,
warp_tile_n=warp_tile_n,
warp_tile_k=warp_tile_k,
m_is_padded=m_is_padded,
n_is_padded=n_is_padded,
k_is_padded=k_is_padded,
pipeline="Mem",
scheduler=scheduler,
epilogue=epilogue,
)
for (layout_a, layout_b, layout_c) in [
("Row", "Row", "Row"),
("Row", "Col", "Row"),
]
for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3]
for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)]
for (warp_m, warp_n, warp_k) in [(2, 2, 1)]
for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)]
for m_is_padded in ["true", "false"]
for n_is_padded in ["true", "false"]
for k_is_padded in ["true", "false"]
for scheduler in ["Intrawave", "Interwave"]
for epilogue in ["Default", "CShuffle"]
]
return list(
itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances)
)
if __name__ == "__main__":
for op in ops():
print(op.name())

View File

@@ -0,0 +1,64 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from dataclasses import asdict, dataclass
@dataclass
class CKTileGemmOperation:
layout_a: str
layout_b: str
layout_c: str
datatype_a: str
datatype_b: str
datatype_c: str
tile_m: int
tile_n: int
tile_k: int
warp_m: int
warp_n: int
warp_k: int
warp_tile_m: int
warp_tile_n: int
warp_tile_k: int
m_is_padded: str
n_is_padded: str
k_is_padded: str
pipeline: str
scheduler: str
epilogue: str
def layout_repr(self):
return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}"
def dtype_repr(self):
return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}"
def tile_sizes(self):
return "_".join(
[
f"{self.tile_m}{self.tile_n}{self.tile_k}",
f"{self.warp_m}{self.warp_n}{self.warp_k}",
f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}",
]
)
def name(self):
return "ck_tile_gemm_universal_" + "_".join(
[
f"{self.layout_repr()}",
f"{self.dtype_repr()}",
f"{self.tile_sizes()}",
f"{self.pipeline}",
f"{self.scheduler}",
f"{self.epilogue}",
]
)
def dict_items(self):
return asdict(self).items()