diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bee1c77c7b..8f710050b1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -772,7 +772,7 @@ class FmhaBwdApiPool: per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' + per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;' result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) return result.replace('\n\n', '\n') diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index c1f85cb5e6..a196807b83 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -75,9 +75,9 @@ struct FmhaBwdDQDKDVKernel static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ; static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0)); #if defined(__gfx950__) - static constexpr bool kIsAvialable = true; + static constexpr bool kIsAvailable = true; #else - static constexpr bool kIsAvialable = !kUseTrLoad; + static constexpr bool kIsAvailable = !kUseTrLoad; #endif // clang-format off @@ -113,7 +113,9 @@ struct FmhaBwdDQDKDVKernel "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" + - ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + + ("o" + _TS_(kBlockPerCu)) + "_" + + ("maxq" + _TS_(kMaxSeqLenQ)) + + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); @@ -676,7 +678,7 @@ struct FmhaBwdDQDKDVKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - if constexpr(kIsAvialable) + if constexpr(kIsAvailable) run_(std::move(kargs)); }