diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 2319a893e6..b5bb8bd811 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -182,20 +182,20 @@ struct reference_hstu_attention { if constexpr(kIsJagged) { - InOutDataType preg = ck_tile::type_convert(locals[sk]); + GemmAccDataType preg = + ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += ck_tile::type_convert(preg) * - ck_tile::type_convert(vreg); + dot_prod += preg * ck_tile::type_convert(vreg); } else { - InOutDataType preg = ck_tile::type_convert(locals[sk]); + GemmAccDataType preg = + ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); - dot_prod += ck_tile::type_convert(preg) * - ck_tile::type_convert(vreg); + dot_prod += preg * ck_tile::type_convert(vreg); }; };