Update the reference HSTU attention so that it no longer converts P from fp32 to fp16/bf16 before the P@V GEMM

This commit is contained in:
Qianfeng Zhang
2025-05-20 07:50:56 +00:00
parent 0a8ea6bd02
commit a1346aaf3e

View File

@@ -182,20 +182,20 @@ struct reference_hstu_attention
{
if constexpr(kIsJagged)
{
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
GemmAccDataType preg =
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
InOutDataType vreg =
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
}
else
{
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
GemmAccDataType preg =
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
};
};