From 712cbfb30498aefa6a13fd2301740a2ca6bc3ae8 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 25 Sep 2025 11:00:10 +0800 Subject: [PATCH] fix fmha fwd kernel name (#2880) * fix fmha fwd kernel name * if the input and output types are the same, keep the original code [ROCm/composable_kernel commit: ab22f91a7c63a34af3198411d064a760b1edebbc] --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index e562f6dd5a..29950435fa 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -72,12 +72,14 @@ struct FmhaFwdKernel static constexpr std::string_view kPipelineName = FmhaPipeline::name; // clang-format off - template struct t2s; + template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; template <> struct t2s { static constexpr const char * name = "fp8"; }; template <> struct t2s { static constexpr const char * name = "bf8"; }; + template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; // clang-format on CK_TILE_HOST static std::string GetName() @@ -99,7 +101,7 @@ struct FmhaFwdKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +