mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Move dividing by max_seqlen to end of Gemm1 loop in the reference hstu-attention codes
This commit is contained in:
@@ -199,8 +199,8 @@ auto get_elimit()
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 2e-3;
|
||||
double atol = 2e-3;
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ struct reference_hstu_attention
|
||||
auto silu = [&](CompDataType x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
return x / (one + std::exp(-x)) / ck_tile::type_convert<CompDataType>(max_seqlen);
|
||||
return x / (one + std::exp(-x));
|
||||
};
|
||||
|
||||
auto f = [&](auto i_batch, auto i_head) {
|
||||
@@ -199,6 +199,8 @@ struct reference_hstu_attention
|
||||
};
|
||||
};
|
||||
|
||||
dot_prod = dot_prod / ck_tile::type_convert<GemmAccDataType>(max_seqlen);
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
|
||||
Reference in New Issue
Block a user