fix fmha fwd kernel name (#2880)

* fix fmha fwd kernel name

* if the input and output types are the same, keep the original code
This commit is contained in:
ltqin
2025-09-25 11:00:10 +08:00
committed by GitHub
parent df97a286d5
commit ab22f91a7c

View File

@@ -72,12 +72,14 @@ struct FmhaFwdKernel
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
// clang-format off
template <typename T> struct t2s;
template <typename T1, typename T2 = T1> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
@@ -99,7 +101,7 @@ struct FmhaFwdKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +