mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Convert P to fp16/bf16 before doing second gemm in reference hstu implementation
This commit is contained in:
@@ -182,20 +182,20 @@ struct reference_hstu_attention
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
GemmAccDataType preg =
|
||||
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
|
||||
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg =
|
||||
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
|
||||
|
||||
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
}
|
||||
else
|
||||
{
|
||||
GemmAccDataType preg =
|
||||
ck_tile::type_convert<GemmAccDataType>(locals[sk]);
|
||||
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += preg * ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user