From a1346aaf3ebb0f9873ccd1f025f1af696f05bfb0 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 20 May 2025 07:50:56 +0000 Subject: [PATCH] Update the reference hstu to not do fp32 to fp16/bf16 conversion before P@V gemm --- .../18_hstu_attention/reference_hstu_attention.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 2319a893e6..b5bb8bd811 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -182,20 +182,20 @@ struct reference_hstu_attention { if constexpr(kIsJagged) { - InOutDataType preg = ck_tile::type_convert(locals[sk]); + GemmAccDataType preg = + ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k); - dot_prod += ck_tile::type_convert(preg) * - ck_tile::type_convert(vreg); + dot_prod += preg * ck_tile::type_convert(vreg); } else { - InOutDataType preg = ck_tile::type_convert(locals[sk]); + GemmAccDataType preg = + ck_tile::type_convert(locals[sk]); InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); - dot_prod += ck_tile::type_convert(preg) * - ck_tile::type_convert(vreg); + dot_prod += preg * ck_tile::type_convert(vreg); }; };