From 2576e64bf682a4ba2f5229502928b6f6f4686cdb Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Mon, 10 Mar 2025 15:46:51 +0000 Subject: [PATCH] Update reference mx gemm --- .../cpu/reference_mx_gemm.hpp | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 649f130c41..5d4ced9486 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -89,9 +89,27 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - a_m_k_scaled(m, k) = - type_convert(arg.a_m_k_(m, k)) * - type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if constexpr(is_same_v) + { + if(k % 2 == 1) + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<1>{})) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + else + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).template unpack<>(Number<0>{})) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + else + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } } } @@ -99,9 +117,27 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - b_k_n_scaled(k, n) = - type_convert(arg.b_k_n_(k, n)) * - type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + if constexpr(is_same_v) + { + if(k % 2 == 1) + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<1>{})) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + else + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).template unpack<>(Number<0>{})) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + else + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } } }