Update reference mx gemm

This commit is contained in:
Rostyslav Geyyer
2025-03-10 15:46:51 +00:00
parent ac0224532e
commit 2576e64bf6

View File

@@ -89,9 +89,27 @@ struct ReferenceMXGemm : public device::BaseOperator
{
for(size_t k = 0; k < K; k++)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<1>{})) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<0>{})) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
}
}
@@ -99,9 +117,27 @@ struct ReferenceMXGemm : public device::BaseOperator
{
for(size_t k = 0; k < K; k++)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<1>{})) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<0>{})) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
}
}