diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cc0778de4c..304ce070ff 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,3 +60,4 @@ add_subdirectory(grouped_gemm) add_subdirectory(convnd_fwd) add_subdirectory(reduce) add_subdirectory(conv2d_bwd_weight) +add_subdirectory(cgemm) diff --git a/test/cgemm/CMakeLists.txt b/test/cgemm/CMakeLists.txt new file mode 100644 index 0000000000..15fa9725fd --- /dev/null +++ b/test/cgemm/CMakeLists.txt @@ -0,0 +1,11 @@ +add_test_executable(test_cgemm_fp32 cgemm_fp32.cpp) +target_link_libraries(test_cgemm_fp32 PRIVATE host_tensor) +target_link_libraries(test_cgemm_fp32 PRIVATE device_cgemm_instance) + +add_test_executable(test_cgemm_fp16 cgemm_fp16.cpp) +target_link_libraries(test_cgemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_cgemm_fp16 PRIVATE device_cgemm_instance) + +add_test_executable(test_cgemm_bf16 cgemm_bf16.cpp) +target_link_libraries(test_cgemm_bf16 PRIVATE host_tensor) +target_link_libraries(test_cgemm_bf16 PRIVATE device_cgemm_instance) diff --git a/test/cgemm/cgemm_bf16.cpp b/test/cgemm/cgemm_bf16.cpp new file mode 100644 index 0000000000..e97b10d0b5 --- /dev/null +++ b/test/cgemm/cgemm_bf16.cpp @@ -0,0 +1,114 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "cgemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_cgemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceCGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_cgemm_instance { +void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); +} // namespace device_cgemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemmBF16{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemmBF16{}(gemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemmBF16{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemmBF16{}(cgemmPtr); + } + + std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/cgemm/cgemm_fp16.cpp b/test/cgemm/cgemm_fp16.cpp new file mode 100644 index 0000000000..de8c0f4c77 --- /dev/null +++ b/test/cgemm/cgemm_fp16.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "cgemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceCGemmNoOpPtr = + ck::tensor_operation::device::DevicecgemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_cgemm_instance { +void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +} // namespace device_cgemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector cgemmPtrs; + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/cgemm/cgemm_fp32.cpp b/test/cgemm/cgemm_fp32.cpp new file mode 100644 index 0000000000..33d3864c37 --- /dev/null +++ b/test/cgemm/cgemm_fp32.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "cgemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_cgemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceCGemmNoOpPtr = + ck::tensor_operation::device::DevicecgemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_cgemm_instance { +void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector&); + +} // namespace device_cgemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector cgemmPtrs; + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + cgemmPtrs.clear(); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: + add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); + + for(auto& cgemmPtr : cgemmPtrs) + { + res &= ck::cgemm_util::TestCGemm{}(cgemmPtr); + } + + std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/cgemm/cgemm_util.hpp b/test/cgemm/cgemm_util.hpp new file mode 100644 index 0000000000..93c36ca3a5 --- /dev/null +++ b/test/cgemm/cgemm_util.hpp @@ -0,0 +1,455 @@ +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_cgemm.hpp" +#include "tensor_layout.hpp" + +namespace ck { +namespace cgemm_util { + +struct CGemmParams +{ + CGemmParams() + : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) + { + } + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostCGEMM(const Tensor& A_real, + const Tensor& A_imag, + const Tensor& B_real, + const Tensor& B_imag, + Tensor& C_real, + Tensor& C_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_cgemm = CGemmInstance{}; + auto ref_invoker = ref_cgemm.MakeInvoker(); + + auto ref_argument = ref_cgemm.MakeArgument( + A_real, A_imag, B_real, B_imag, C_real, C_imag, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, + const ck::cgemm_util::CGemmParams& params, + const Tensor& A_real, + const Tensor& A_imag, + const Tensor& B_real, + const Tensor& B_imag, + Tensor& C_real, + Tensor& C_imag, + Tensor& Aux, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); + DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); + DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); + DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); + DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + DeviceMem aux_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + + auto invoker_ptr = cgemmPtr->MakeInvokerPointer(); + auto argument_ptr = cgemmPtr->MakeArgumentPointer( + static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), + static_cast(aux_device_buf.GetDeviceBuffer()), + params.M, + params.N, + params.K, + params.StrideA, + params.StrideB, + params.StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!cgemmPtr->IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error( + "wrong! device_cgemm with the specified compilation parameters does " + "not support this CGEMM problem"); + } + + invoker_ptr->Run(argument_ptr.get()); + c_m_n_real_device_buf.FromDevice(C_real.mData.data()); + c_m_n_imag_device_buf.FromDevice(C_imag.mData.data()); +} + +template +struct TestCGemm +{ + auto PrepareCGemmTensor(const ck::cgemm_util::CGemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k_real( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor a_m_k_imag( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_real( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor b_k_n_imag( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_real_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_imag_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_real_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_imag_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor aux( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k_real, ADataType{}); + f_generate_tensor_value(a_m_k_imag, ADataType{}); + f_generate_tensor_value(b_k_n_real, BDataType{}); + f_generate_tensor_value(b_k_n_imag, BDataType{}); + + return std::make_tuple(a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real_host_result, + c_m_n_imag_host_result, + c_m_n_real_device_result, + c_m_n_imag_device_result, + aux); + } + + auto operator()(DeviceCGemmPtr_& cgemmPtr) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + std::cout << cgemmPtr->GetTypeString() << std::endl; + + // Arrange + ck::cgemm_util::CGemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareCGemmTensor(params); + + const Tensor& a_real = std::get<0>(host_tensors); + const Tensor& a_imag = std::get<1>(host_tensors); + const Tensor& b_real = std::get<2>(host_tensors); + const Tensor& b_imag = std::get<3>(host_tensors); + Tensor& c_host_real = std::get<4>(host_tensors); + Tensor& c_host_imag = std::get<5>(host_tensors); + Tensor& c_device_real = std::get<6>(host_tensors); + Tensor& c_device_imag = std::get<7>(host_tensors); + Tensor& aux = std::get<8>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceCGemm; + ck::cgemm_util::RunHostCGEMM(a_real, + a_imag, + b_real, + b_imag, + c_host_real, + c_host_imag, + a_element_op, + b_element_op, + c_element_op); + + // Act + ck::cgemm_util::RunDeviceCGEMM(cgemmPtr, + params, + a_real, + a_imag, + b_real, + b_imag, + c_device_real, + c_device_imag, + aux, + a_element_op, + b_element_op, + c_element_op); + + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) && + ck::utils::check_err(c_device_real.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) && + ck::utils::check_err(c_device_imag.mData, c_host_imag.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) && + ck::utils::check_err(c_device_imag.mData, c_host_imag.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + + return res; + } +}; + +template +struct TestCGemmBF16 +{ + using BF16 = ck::bhalf_t; + + auto PrepareCGemmTensorBF16(const ck::cgemm_util::CGemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + // use fp32 host kernel to verify bf16 device kernel + Tensor a_m_k_real_bf16( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor a_m_k_imag_bf16( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_real_bf16( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor b_k_n_imag_bf16( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_real_device_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_imag_device_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor aux_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + Tensor a_m_k_real_fp32( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor a_m_k_imag_fp32( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_real_fp32( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor b_k_n_imag_fp32( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_real_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_host_imag_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_real_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_imag_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor aux_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k_imag_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_real_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_imag_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + bf16_to_f32_(a_m_k_real_bf16, a_m_k_real_fp32); + bf16_to_f32_(a_m_k_imag_bf16, a_m_k_imag_fp32); + bf16_to_f32_(b_k_n_real_bf16, b_k_n_imag_fp32); + bf16_to_f32_(b_k_n_real_bf16, b_k_n_imag_fp32); + + return std::make_tuple(a_m_k_real_bf16, + a_m_k_imag_bf16, + b_k_n_real_bf16, + b_k_n_imag_bf16, + c_m_n_real_device_bf16, + c_m_n_imag_device_bf16, + aux_bf16, + a_m_k_real_fp32, + a_m_k_imag_fp32, + b_k_n_real_fp32, + b_k_n_imag_fp32, + c_m_n_real_host_fp32, + c_m_n_imag_host_fp32, + c_m_n_real_device_fp32, + c_m_n_imag_device_fp32, + aux_fp32); + } + + auto operator()(DeviceCGemmPtr_& cgemmPtr) + { + // Arrange + ck::cgemm_util::CGemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareCGemmTensorBF16(params); + const Tensor& a_real_bf16 = std::get<0>(host_tensors); + const Tensor& a_imag_bf16 = std::get<1>(host_tensors); + const Tensor& b_real_bf16 = std::get<2>(host_tensors); + const Tensor& b_imag_bf16 = std::get<3>(host_tensors); + Tensor& c_real_device_bf16 = std::get<4>(host_tensors); + Tensor& c_imag_device_bf16 = std::get<5>(host_tensors); + Tensor& aux_bf16 = std::get<6>(host_tensors); + Tensor& a_real_fp32 = std::get<7>(host_tensors); + Tensor& a_imag_fp32 = std::get<8>(host_tensors); + Tensor& b_real_fp32 = std::get<9>(host_tensors); + Tensor& b_imag_fp32 = std::get<10>(host_tensors); + Tensor& c_real_host_fp32 = std::get<11>(host_tensors); + Tensor& c_imag_host_fp32 = std::get<12>(host_tensors); + Tensor& c_real_device_fp32 = std::get<13>(host_tensors); + Tensor& c_imag_device_fp32 = std::get<14>(host_tensors); + Tensor& aux_fp32 = std::get<15>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + // use fp32 host kernel to verify bf16 device kernel + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceCGemm; + ck::gemm_util::RunHostCGEMM(a_real_fp32, + a_imag_fp32, + b_real_fp32, + b_imag_fp32, + c_real_host_fp32, + c_imag_fp32, + a_element_op, + b_element_op, + c_element_op); + + // Act + ck::gemm_util::RunDeviceCGEMM(cgemmPtr, + params, + a_real_bf16, + a_imag_bf16, + b_real_bf16, + b_imag_bf16, + c_real_device_bf16, + c_imag_device_bf16, + aux_bf16, + a_element_op, + b_element_op, + c_element_op); + + bf16_to_f32_(c_real_device_bf16, c_real_device_fp32); + bf16_to_f32_(c_imag_device_bf16, c_imag_device_fp32); + + // Assert + bool res = ck::utils::check_err(c_real_device_fp32.mData, + c_real_host_fp32.mData, + "Error: incorrect results!", + 1e-2f, + 1e-3f) && + ck::utils::check_err(c_imag_device_fp32.mData, + c_imag_host_fp32.mData, + "Error: incorrect results!", + 1e-2f, + 1e-3f); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + + return res; + }; +}; + +} // namespace cgemm_util +} // namespace ck +#endif