From 0c2d00df8a6dbee36d9f79b84b534ecc8154bc0e Mon Sep 17 00:00:00 2001
From: myamlak
Date: Wed, 11 May 2022 08:25:01 +0000
Subject: [PATCH] Reference CGEMM + test stub

---
Review notes (free-text area, not part of the commit message):
- reference_cgemm.hpp: the imaginary accumulation now adds the cross terms
  (a_re * b_im + a_im * b_re); it previously subtracted them.
- reference_cgemm.hpp: the A real/imag size check runs once per Run() call
  instead of once per output element, and a failed argument cast now throws.
- Text inside angle brackets was missing from the reviewed copy of this patch.
  It has been reconstructed in reference_cgemm.hpp (see NOTE(review) there).
  The six system includes in cgemm_xdl_bf16.cpp could not be recovered and
  are left as found -- restore them before applying.
- Hunk length and diffstat were updated for reference_cgemm.hpp; the blob id
  on its index line is stale.

 example/19_cgemm/CMakeLists.txt               |   1 +
 example/19_cgemm/cgemm_xdl_bf16.cpp           |  23 ++
 example/CMakeLists.txt                        |   3 +-
 .../cpu/reference_cgemm.hpp                   | 233 ++++++++++++++++++
 4 files changed, 259 insertions(+), 1 deletion(-)
 create mode 100644 example/19_cgemm/CMakeLists.txt
 create mode 100644 example/19_cgemm/cgemm_xdl_bf16.cpp
 create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp

diff --git a/example/19_cgemm/CMakeLists.txt b/example/19_cgemm/CMakeLists.txt
new file mode 100644
index 0000000000..36b0b79b1a
--- /dev/null
+++ b/example/19_cgemm/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
diff --git a/example/19_cgemm/cgemm_xdl_bf16.cpp b/example/19_cgemm/cgemm_xdl_bf16.cpp
new file mode 100644
index 0000000000..739323dd52
--- /dev/null
+++ b/example/19_cgemm/cgemm_xdl_bf16.cpp
@@ -0,0 +1,23 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "check_err.hpp"
+#include "config.hpp"
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "device_tensor.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
+#include "element_wise_operation.hpp"
+#include "reference_cgemm.hpp"
+#include "gemm_specialization.hpp"
+
+// stub only
+int main()
+{
+    return 0;
+}
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 5f04125305..5ea3889844 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -38,7 +38,8 @@ add_subdirectory(11_conv2d_bwd_weight)
 add_subdirectory(12_reduce)
 add_subdirectory(13_pool2d_fwd)
 add_subdirectory(14_gemm_xdl_requant_relu_requant)
-add_subdirectory(17_convnd_bwd_data_xdl)
 add_subdirectory(15_grouped_gemm)
 add_subdirectory(16_gemm_reduce)
+add_subdirectory(17_convnd_bwd_data_xdl)
 add_subdirectory(18_batched_gemm_reduce)
+add_subdirectory(19_cgemm)
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
new file mode 100644
index 0000000000..aa4addab23
--- /dev/null
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
@@ -0,0 +1,233 @@
+#pragma once
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include "device_base.hpp"
+#include "host_tensor.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace host {
+
+/// Host-side reference complex GEMM on split real / imaginary tensors.
+///
+/// For each output element (m, n), accumulating over k in float:
+///   C_real(m, n) = c_op(sum_k(a_re * b_re - a_im * b_im))
+///   C_imag(m, n) = c_op(sum_k(a_re * b_im + a_im * b_re))
+/// where a_* / b_* are the inputs after a_element_op / b_element_op.
+///
+/// NOTE(review): the template parameter list was missing from the reviewed
+/// text. The three *ElementwiseOperation names are used below; the three
+/// data-type parameters and the overall order are assumed -- confirm against
+/// the callers before merging.
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct ReferenceCGemm : public device::BaseOperator
+{
+    /// Operand tensors and element-wise operations for one invocation.
+    /// Tensors are stored as references and must outlive the Argument.
+    struct Argument : public device::BaseArgument
+    {
+        Argument(const Tensor<ADataType>& a_m_k_real,
+                 const Tensor<ADataType>& a_m_k_imag,
+                 const Tensor<BDataType>& b_k_n_real,
+                 const Tensor<BDataType>& b_k_n_imag,
+                 Tensor<CDataType>& c_m_n_real,
+                 Tensor<CDataType>& c_m_n_imag,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CElementwiseOperation c_element_op)
+            : a_m_k_real_{a_m_k_real},
+              a_m_k_imag_{a_m_k_imag},
+              b_k_n_real_{b_k_n_real},
+              b_k_n_imag_{b_k_n_imag},
+              c_m_n_real_{c_m_n_real},
+              c_m_n_imag_{c_m_n_imag},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              c_element_op_{c_element_op}
+        {
+        }
+
+        const Tensor<ADataType>& a_m_k_real_;
+        const Tensor<ADataType>& a_m_k_imag_;
+        const Tensor<BDataType>& b_k_n_real_;
+        const Tensor<BDataType>& b_k_n_imag_;
+        Tensor<CDataType>& c_m_n_real_;
+        Tensor<CDataType>& c_m_n_imag_;
+
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+    };
+
+    /// Runs the reference computation on the host.
+    struct Invoker : public device::BaseInvoker
+    {
+        using Argument = ReferenceCGemm::Argument;
+
+        /// Fills both output tensors of `arg`; always returns 0.
+        /// @throws std::runtime_error if the real and imaginary A tensors
+        ///         disagree on their second dimension.
+        float Run(const Argument& arg)
+        {
+            // Checked once here instead of once per output element inside the
+            // functors below, so a mismatch is reported before any work starts.
+            if(arg.a_m_k_real_.mDesc.GetLengths()[1] !=
+               arg.a_m_k_imag_.mDesc.GetLengths()[1])
+            {
+                throw std::runtime_error(
+                    "wrong! Incompatible real and imag sizes in CGEMM");
+            }
+
+            const int K = arg.a_m_k_real_.mDesc.GetLengths()[1];
+
+            // Reads A(m, k) and B(k, n) (real and imaginary parts) through the
+            // A/B element-wise operations.
+            auto load_operands = [&](auto m,
+                                     auto n,
+                                     int k,
+                                     float& v_a_real,
+                                     float& v_a_imag,
+                                     float& v_b_real,
+                                     float& v_b_imag) {
+                arg.a_element_op_(v_a_real, static_cast<float>(arg.a_m_k_real_(m, k)));
+                arg.a_element_op_(v_a_imag, static_cast<float>(arg.a_m_k_imag_(m, k)));
+                arg.b_element_op_(v_b_real, static_cast<float>(arg.b_k_n_real_(k, n)));
+                arg.b_element_op_(v_b_imag, static_cast<float>(arg.b_k_n_imag_(k, n)));
+            };
+
+            // Re(a * b) = a_re * b_re - a_im * b_im
+            auto f_mk_kn_mn_real = [&](auto m, auto n) {
+                float v_acc = 0;
+
+                for(int k = 0; k < K; ++k)
+                {
+                    float v_a_real = 0;
+                    float v_a_imag = 0;
+                    float v_b_real = 0;
+                    float v_b_imag = 0;
+
+                    load_operands(m, n, k, v_a_real, v_a_imag, v_b_real, v_b_imag);
+
+                    v_acc += v_a_real * v_b_real - v_a_imag * v_b_imag;
+                }
+
+                float v_c_real = 0;
+
+                arg.c_element_op_(v_c_real, v_acc);
+
+                arg.c_m_n_real_(m, n) = v_c_real;
+            };
+
+            // Im(a * b) = a_re * b_im + a_im * b_re. The cross terms add; the
+            // previous revision subtracted them, which produced wrong
+            // imaginary outputs.
+            auto f_mk_kn_mn_imag = [&](auto m, auto n) {
+                float v_acc = 0;
+
+                for(int k = 0; k < K; ++k)
+                {
+                    float v_a_real = 0;
+                    float v_a_imag = 0;
+                    float v_b_real = 0;
+                    float v_b_imag = 0;
+
+                    load_operands(m, n, k, v_a_real, v_a_imag, v_b_real, v_b_imag);
+
+                    v_acc += v_a_real * v_b_imag + v_a_imag * v_b_real;
+                }
+
+                float v_c_imag = 0;
+
+                arg.c_element_op_(v_c_imag, v_acc);
+
+                arg.c_m_n_imag_(m, n) = v_c_imag;
+            };
+
+            make_ParallelTensorFunctor(f_mk_kn_mn_real,
+                                       arg.c_m_n_real_.mDesc.GetLengths()[0],
+                                       arg.c_m_n_real_.mDesc.GetLengths()[1])(
+                std::thread::hardware_concurrency());
+            make_ParallelTensorFunctor(f_mk_kn_mn_imag,
+                                       arg.c_m_n_imag_.mDesc.GetLengths()[0],
+                                       arg.c_m_n_imag_.mDesc.GetLengths()[1])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+
+        /// Type-erased entry point; the second parameter is ignored.
+        /// @throws std::runtime_error if `p_arg` is not a ReferenceCGemm::Argument.
+        float Run(const device::BaseArgument* p_arg, int) override
+        {
+            const auto* p_typed = dynamic_cast<const Argument*>(p_arg);
+
+            // A failed cast used to be dereferenced unchecked.
+            if(p_typed == nullptr)
+            {
+                throw std::runtime_error("wrong! ReferenceCGemm got an incompatible argument");
+            }
+
+            return Run(*p_typed);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+
+    static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
+                             const Tensor<ADataType>& a_m_k_imag,
+                             const Tensor<BDataType>& b_k_n_real,
+                             const Tensor<BDataType>& b_k_n_imag,
+                             Tensor<CDataType>& c_m_n_real,
+                             Tensor<CDataType>& c_m_n_imag,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{a_m_k_real,
+                        a_m_k_imag,
+                        b_k_n_real,
+                        b_k_n_imag,
+                        c_m_n_real,
+                        c_m_n_imag,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "ReferenceCGemm"
+            << '\n';
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace host
+} // namespace tensor_operation
+} // namespace ck