From 781cba355af76e3bead95fef30cfad4e92e38f8b Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 29 May 2025 01:04:19 +0000 Subject: [PATCH] Convert P to fp16/bf16 before doing second gemm in reference hstu implementation --- .../18_hstu_attention/reference_hstu_attention.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index b5bb8bd811..2319a893e6 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -182,20 +182,20 @@ struct reference_hstu_attention { if constexpr(kIsJagged) { - GemmAccDataType preg = - ck_tile::type_convert<GemmAccDataType>(locals[sk]); + InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg); + dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) * + ck_tile::type_convert<GemmAccDataType>(vreg); } else { - GemmAccDataType preg = - ck_tile::type_convert<GemmAccDataType>(locals[sk]); + InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); - dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg); + dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) * + ck_tile::type_convert<GemmAccDataType>(vreg); }; };