mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Update reference kernels
This commit is contained in:
@@ -81,6 +81,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
|
||||
@@ -106,6 +107,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
|
||||
@@ -120,6 +122,11 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
// if ((m == 2) && (n == 0))
|
||||
// {
|
||||
// printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc);
|
||||
// }
|
||||
}
|
||||
|
||||
CDataType v_c{0};
|
||||
|
||||
@@ -91,16 +91,17 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
- arg.a_m_k_(m, k).template unpack<>(Number<1>{})) *
|
||||
+ f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
else
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
- arg.a_m_k_(m, k).template unpack<>(Number<0>{})) *
|
||||
+ f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
@@ -122,13 +123,13 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
- arg.b_k_n_(k, n).template unpack<>(Number<1>{})) *
|
||||
+ f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
else
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
- arg.b_k_n_(k, n).template unpack<>(Number<0>{})) *
|
||||
+ f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user