Move the division by max_seqlen to the end of the Gemm1 loop in the reference hstu-attention code

This commit is contained in:
Qianfeng Zhang
2025-05-30 16:02:45 +00:00
parent bec35abafe
commit 9582ae2dff
2 changed files with 5 additions and 3 deletions

View File

@@ -199,8 +199,8 @@ auto get_elimit()
// Explicit specialization of get_elimit for ck_tile::bf16_t: returns the
// (rtol, atol) pair packed into a ck_tile tuple.
// Presumably these are the relative/absolute tolerances used when comparing
// bf16 results against the reference -- confirm at the call site.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
// NOTE(review): this text is a rendered diff with the +/- markers stripped.
// The 2e-3 pair appears to be the removed lines and the 1e-3 pair the added
// ones (the hunk header reports 8 lines both before and after the change);
// only one pair exists in the real file -- confirm against the repository.
double rtol = 2e-3;
double atol = 2e-3;
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}

View File

@@ -102,7 +102,7 @@ struct reference_hstu_attention
// SiLU activation evaluated in CompDataType: x / (1 + exp(-x)).
// NOTE(review): this is a rendered diff with the +/- markers stripped. The
// first return (which also divides by max_seqlen) appears to be the removed
// line and the second return the added one; per the commit title the
// max_seqlen division is moved to the end of the Gemm1 loop instead --
// confirm against the repository.
auto silu = [&](CompDataType x) {
// Constant 1 converted to the computation type.
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
return x / (one + std::exp(-x)) / ck_tile::type_convert<CompDataType>(max_seqlen);
return x / (one + std::exp(-x));
};
auto f = [&](auto i_batch, auto i_head) {
@@ -199,6 +199,8 @@ struct reference_hstu_attention
};
};
dot_prod = dot_prod / ck_tile::type_convert<GemmAccDataType>(max_seqlen);
if constexpr(kIsJagged)
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);