Update non scaled reference kernel

This commit is contained in:
Rostyslav Geyyer
2025-03-12 16:04:48 +00:00
parent 489961c19a
commit ae96e32fad

View File

@@ -79,6 +79,15 @@ struct ReferenceGemm : public device::BaseOperator
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
if(k % 2 == 1)
v_a = type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<1>{}));
else
v_a = type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<0>{}));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -95,6 +104,15 @@ struct ReferenceGemm : public device::BaseOperator
i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4);
}
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
if(k % 2 == 1)
v_b = type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<1>{}));
else
v_b = type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<0>{}));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));