diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 649f130c41..5d4ced9486 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -89,9 +89,27 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - a_m_k_scaled(m, k) = - type_convert(arg.a_m_k_(m, k)) * - type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if constexpr(is_same_v) + { + if(k % 2 == 1) + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<1>{})) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + else + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<0>{})) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + else + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } } } @@ -99,9 +117,27 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - b_k_n_scaled(k, n) = - type_convert(arg.b_k_n_(k, n)) * - type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + if constexpr(is_same_v) + { + if(k % 2 == 1) + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<1>{})) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + else + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<0>{})) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + else + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } } }