From ae96e32faddf3331272d68fc073c1a26689fceb1 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Wed, 12 Mar 2025 16:04:48 +0000 Subject: [PATCH] Update non scaled reference kernel --- .../cpu/reference_gemm.hpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 7e2482807d..a5c42831be 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -79,6 +79,15 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_a = type_convert(i4); } + else if constexpr(is_same_v) + { + if(k % 2 == 1) + v_a = type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<1>{})); + else + v_a = type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<0>{})); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -95,6 +104,15 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_b = type_convert(i4); } + else if constexpr(is_same_v) + { + if(k % 2 == 1) + v_b = type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<1>{})); + else + v_b = type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<0>{})); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n));