diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index e80e048a42..1029dcec2f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -81,6 +81,7 @@ struct ReferenceGemm : public device::BaseOperator } else if constexpr(is_same_v) { + // TODO: add support for ColMajor layout as well if(k % 2 == 1) v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); @@ -106,6 +107,7 @@ struct ReferenceGemm : public device::BaseOperator } else if constexpr(is_same_v) { + // TODO: add support for RowMajor layout as well if(k % 2 == 1) v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); @@ -120,6 +122,11 @@ struct ReferenceGemm : public device::BaseOperator v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + + // if ((m == 2) && (n == 0)) + // { + // printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc); + // } } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 5d4ced9486..e0697e3360 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -91,16 +91,17 @@ struct ReferenceMXGemm : public device::BaseOperator { if constexpr(is_same_v) { + // TODO: add support for ColMajor layout as well if(k % 2 == 1) a_m_k_scaled(m, k) = type_convert( - arg.a_m_k_(m, k).template unpack<>(Number<1>{})) * + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); else a_m_k_scaled(m, k) = type_convert( - arg.a_m_k_(m, k).template unpack<>(Number<0>{})) * + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); } @@ -122,13 +123,13 @@ struct ReferenceMXGemm : public device::BaseOperator if(k % 2 == 1) b_k_n_scaled(k, n) = type_convert( - arg.b_k_n_(k, n).template unpack<>(Number<1>{})) * + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); else b_k_n_scaled(k, n) = type_convert( - arg.b_k_n_(k, n).template unpack<>(Number<0>{})) * + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); }