update bwd kernel launch

This commit is contained in:
danyao12
2024-05-28 23:14:18 +08:00
parent ba6437868b
commit 1c511b3e7d
4 changed files with 101 additions and 7 deletions

View File

@@ -1187,6 +1187,32 @@ struct FmhaBwdOGradDotOKernel
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
#undef _SS_
#undef _TS_
// clang-format on
}
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.