This commit is contained in:
Juuso Korhonen
2025-11-24 10:20:04 +00:00
parent f552cd7841
commit f2fbc44b7b
3 changed files with 23 additions and 21 deletions

View File

@@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
-1,
0,
seqlen_q,
seqlen_kv,
true));
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
// -1,
// 0,
// seqlen_q,
// seqlen_kv,
// true));
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
@@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);