diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index b5bb8bd811..2319a893e6 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -182,20 +182,20 @@ struct reference_hstu_attention { if constexpr(kIsJagged) { - GemmAccDataType preg = - ck_tile::type_convert(locals[sk]); + InOutDataType preg = ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += preg * ck_tile::type_convert(vreg); + dot_prod += ck_tile::type_convert(preg) * + ck_tile::type_convert(vreg); } else { - GemmAccDataType preg = - ck_tile::type_convert(locals[sk]); + InOutDataType preg = ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); - dot_prod += preg * ck_tile::type_convert(vreg); + dot_prod += ck_tile::type_convert(preg) * + ck_tile::type_convert(vreg); }; };