Convert P to fp16/bf16 before doing the second GEMM in the reference HSTU implementation

This commit is contained in:
Qianfeng Zhang
2025-05-29 01:04:19 +00:00
parent 36a0f2020c
commit 781cba355a

View File

@@ -182,20 +182,20 @@ struct reference_hstu_attention
{
if constexpr(kIsJagged)
{
GemmAccDataType preg =
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg =
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
}
else
{
GemmAccDataType preg =
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
};
};