Merge commit '91178b401197c086c4e94096a2005b2266d297a3' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-09 22:11:29 +00:00
parent 818354c0c2
commit 62c8b1c1f6
2 changed files with 7 additions and 5 deletions

View File

@@ -772,7 +772,7 @@ class FmhaBwdApiPool:
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += ' (void)t ; (void)s ; (void)a;'
per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;'
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
return result.replace('\n\n', '\n')

View File

@@ -75,9 +75,9 @@ struct FmhaBwdDQDKDVKernel
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
#if defined(__gfx950__)
static constexpr bool kIsAvialable = true;
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvialable = !kUseTrLoad;
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
// clang-format off
@@ -113,7 +113,9 @@ struct FmhaBwdDQDKDVKernel
"r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
("o" + _TS_(kBlockPerCu)) + "_" +
("maxq" + _TS_(kMaxSeqLenQ)) +
(pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
@@ -676,7 +678,7 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(kIsAvialable)
if constexpr(kIsAvailable)
run_(std::move(kargs));
}