Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-19 22:06:56 +00:00
parent 0b87df9c4a
commit 9d088bc569
15 changed files with 1602 additions and 588 deletions

View File

@@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
is_same_v<ADataType, f6x32_pk_t> ||
is_same_v<ADataType, bf6x32_pk_t>)
{
v_a = type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||
is_same_v<BDataType, f6x32_pk_t> ||
is_same_v<BDataType, bf6x32_pk_t>)
{
v_b = type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));

View File

@@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
is_same_v<ADataType, f6x32_pk_t> ||
is_same_v<ADataType, bf6x32_pk_t>)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else
{
a_m_k_scaled(m, k) =
@@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||
is_same_v<BDataType, f6x32_pk_t> ||
is_same_v<BDataType, bf6x32_pk_t>)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else
{
b_k_n_scaled(k, n) =