Update reference kernels

This commit is contained in:
Rostyslav Geyyer
2025-04-30 16:34:26 +00:00
parent b27154e689
commit b068584afb
2 changed files with 12 additions and 4 deletions

View File

@@ -81,6 +81,7 @@ struct ReferenceGemm : public device::BaseOperator
}
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
@@ -106,6 +107,7 @@ struct ReferenceGemm : public device::BaseOperator
}
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
@@ -120,6 +122,11 @@ struct ReferenceGemm : public device::BaseOperator
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
// if ((m == 2) && (n == 0))
// {
// printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc);
// }
}
CDataType v_c{0};

View File

@@ -91,16 +91,17 @@ struct ReferenceMXGemm : public device::BaseOperator
{
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<1>{})) *
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
arg.a_m_k_(m, k).template unpack<>(Number<0>{})) *
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
@@ -122,13 +123,13 @@ struct ReferenceMXGemm : public device::BaseOperator
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<1>{})) *
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
arg.b_k_n_(k, n).template unpack<>(Number<0>{})) *
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}