diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e888cbd383..627352e226 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1095,7 +1095,10 @@ class KernelComponentFactoryGfx950( class KernelComponentFactoryGfx11(CompatibilityRuleFactory): - arch = ArchTrait("gfx11") + arch = ArchTrait( + "gfx11", + preprocessor_check="defined(__gfx11__) && !defined(__gfx115__)", + ) _DT_FP16_BF16 = ("fp16", "bf16") @@ -1109,10 +1112,12 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)] } # fmt: skip else: raise ValueError(f"unsupported dtype={dtype}") @@ -1133,12 +1138,25 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + # Keep only ttff/tttt for gfx11: ffff path is often similar or worse + # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines +class KernelComponentFactoryGfx115(KernelComponentFactoryGfx11): + arch = ArchTrait("gfx115") + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = super().get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + result[(64, 64)] = [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip + result[(256, 256)] = [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip + return result + + class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") @@ -1230,6 +1248,8 @@ def get_factory(target: str): if target.startswith("gfx9"): return KernelComponentFactoryGfx9 + if target.startswith("gfx115"): + return KernelComponentFactoryGfx115 if target.startswith("gfx11"): return KernelComponentFactoryGfx11 if target.startswith("gfx12"): diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2d4964f86a..f42526ddf7 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -23,6 +23,7 @@ #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 5069172386..62d7971a8a 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1141,6 +1141,9 @@ struct gfx103_t struct gfx11_t { }; +struct gfx115_t +{ +}; struct gfx12_t { }; @@ -1174,6 +1177,8 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; } +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx115_t) { return 32; } + CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index ed102c86a8..a057ae9052 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -24,6 +24,9 @@ defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) +#define __gfx115__ +#endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) #define __gfx12__ #endif