Merge commit 'ab22f91a7c63a34af3198411d064a760b1edebbc' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-25 03:25:33 +00:00
parent 58b3560182
commit 9d8734c878

View File

@@ -72,12 +72,14 @@ struct FmhaFwdKernel
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
// clang-format off
template <typename T> struct t2s;
template <typename T1, typename T2 = T1> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
@@ -99,7 +101,7 @@ struct FmhaFwdKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +