diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 7e2482807d..a5c42831be 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -79,6 +79,15 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_a = type_convert(i4); } + else if constexpr(is_same_v) + { + if(k % 2 == 1) + v_a = type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<1>{})); + else + v_a = type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<0>{})); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -95,6 +104,15 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_b = type_convert(i4); } + else if constexpr(is_same_v) + { + if(k % 2 == 1) + v_b = type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<1>{})); + else + v_b = type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<0>{})); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n));