add fmha fwd splitkv receipt for aiter c++ api (#2068)

* add s_randval for c++ api

* Fix bug of bias in splitkv

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
slippedJim
2025-04-10 23:21:13 +08:00
committed by GitHub
parent f14e648e7c
commit 5f885d2b7a
5 changed files with 30 additions and 10 deletions

View File

@@ -95,8 +95,8 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
#undef _SS_
#undef _TS_
@@ -563,7 +563,7 @@ struct FmhaFwdSplitKVKernel
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
batch_offset_bias = query_start * kargs.stride_bias;
}
batch_offset_lse_acc = query_start;