diff --git a/pyproject.toml b/pyproject.toml index e8868ed92d..9e1457b7d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,13 +21,22 @@ dependencies = [] "Bug Tracker" = "https://github.com/rocm/composable_kernel/issues" [tool.setuptools] -packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library", "ck4inductor.universal_gemm", "ck4inductor.batched_universal_gemm", "ck4inductor.grouped_conv_fwd"] +packages = [ + "ck4inductor", + "ck4inductor.include", + "ck4inductor.library", + "ck4inductor.universal_gemm", + "ck4inductor.batched_universal_gemm", + "ck4inductor.grouped_conv_fwd", + "ck4inductor.ck_tile_universal_gemm", +] [tool.setuptools.package-dir] ck4inductor = "python/ck4inductor" "ck4inductor.universal_gemm" = "python/ck4inductor/universal_gemm" "ck4inductor.batched_universal_gemm" = "python/ck4inductor/batched_universal_gemm" "ck4inductor.grouped_conv_fwd" = "python/ck4inductor/grouped_conv_fwd" +"ck4inductor.ck_tile_universal_gemm" = "python/ck4inductor/ck_tile_universal_gemm" "ck4inductor.include" = "include" "ck4inductor.library" = "library" diff --git a/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py b/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py new file mode 100644 index 0000000000..6f68a7dbd3 --- /dev/null +++ b/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py @@ -0,0 +1,138 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +import functools +from .op import CKTileGemmOperation + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +if __name__ == "__main__": + for op in ops(): + print(op.name()) diff --git a/python/ck4inductor/ck_tile_universal_gemm/op.py b/python/ck4inductor/ck_tile_universal_gemm/op.py new file mode 100644 index 0000000000..651b3e984e --- /dev/null +++ b/python/ck4inductor/ck_tile_universal_gemm/op.py @@ -0,0 +1,64 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +from dataclasses import asdict, dataclass + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/test/test_gen_instances.py b/python/test/test_gen_instances.py index 56fe4f7342..201b0ad2a9 100644 --- a/python/test/test_gen_instances.py +++ b/python/test/test_gen_instances.py @@ -15,6 +15,9 @@ from ck4inductor.grouped_conv_fwd.gen_instances import ( from ck4inductor.batched_universal_gemm.gen_instances import ( gen_ops_library as gen_batched_gemm_ops_library, ) +from ck4inductor.ck_tile_universal_gemm.gen_instances import ( + ops as gen_ck_tile_gemm_ops_library, +) log = logging.getLogger(__name__) @@ -43,3 +46,9 @@ class TestGenInstances(unittest.TestCase): log.debug("%d gemm instances from library" % len(instances)) self.assertTrue(instances) + + def test_gen_ck_tile_universal_gemm_instances(self): + instances = gen_ck_tile_gemm_ops_library() + + log.debug("%d ck-tile gemm instances from library" % len(instances)) + self.assertTrue(instances)