From 37bfa01c0dce53959ea05abf31a9802ecca66c48 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 7 Feb 2025 09:03:00 -0600 Subject: [PATCH] Add a host mx gemm reference kernel (#1864) * Add mx gemm reference kernel * Update copyright year * Update mx gemm example * Use element-wise ops in the reference gemm --- .../67_gemm_microscaling/gemm_mx_common.hpp | 56 +++--- .../cpu/reference_mx_gemm.hpp | 178 ++++++++++++++++++ 2 files changed, 200 insertions(+), 34 deletions(-) create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 7ba7d4768b..5b00b5a123 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -13,7 +13,7 @@ #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" @@ -315,40 +315,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) std::cout << "Computing GEMM on host..." << std::endl; } - Tensor c({M, N}); - Tensor a({M, K}); - Tensor b({K, N}); - - for(int m = 0; m < M; m++) - { - for(int k = 0; k < K; k++) - { - a(m, k) = ck::type_convert(a_m_k(m, k)) * - ck::type_convert(a_m_k_scale(m, k / Scale_Block_K)); - } - } - - for(int n = 0; n < N; n++) - { - for(int k = 0; k < K; k++) - { - b(k, n) = ck::type_convert(b_k_n(k, n)) * - ck::type_convert(b_k_n_scale(k / Scale_Block_K, n)); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = - ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); ref_invoker.Run(ref_argument); @@ -366,8 +353,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; } - res_verified = res_verified && - ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!"); + res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!"); if(config.verbosity > 0 && res_verified) std::cout << "Done." << std::endl; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp new file mode 100644 index 0000000000..649f130c41 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMXGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& a_m_kblock_scales, + const Tensor& b_k_n, + const Tensor& b_kblock_n_scales, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + a_m_kblock_scales_{a_m_kblock_scales}, + b_k_n_{b_k_n}, + b_kblock_n_scales_{b_kblock_n_scales}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& a_m_kblock_scales_; + const Tensor& b_k_n_; + const Tensor& b_kblock_n_scales_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMXGemm::Argument; + + float Run(const Argument& arg) + { + using GemmInstance = ck::tensor_operation::host::ReferenceGemm; + + Tensor a_m_k_scaled(arg.a_m_k_.mDesc); + Tensor b_k_n_scaled(arg.b_k_n_.mDesc); + + const auto M = arg.a_m_k_.mDesc.GetLengths()[0]; + const auto N = arg.b_k_n_.mDesc.GetLengths()[1]; + const auto K = arg.a_m_k_.mDesc.GetLengths()[1]; + const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1]; + + for(size_t m = 0; m < M; m++) + { + for(size_t k = 0; k < K; k++) + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + } + + for(size_t n = 0; n < N; n++) + { + for(size_t k = 0; k < K; k++) + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + } + + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled, + b_k_n_scaled, + arg.c_m_n_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + + ref_invoker.Run(ref_argument); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& a_m_kblock_scales, + const Tensor& b_k_n, + const Tensor& b_kblock_n_scales, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, + a_m_kblock_scales, + b_k_n, + b_kblock_n_scales, + c_m_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMXGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck