From 674f74ad5fb5dcf59e40764949dc9e950ef1f459 Mon Sep 17 00:00:00 2001 From: myamlak Date: Mon, 16 May 2022 08:11:22 +0000 Subject: [PATCH] Test fixes. --- .../cpu/reference_cgemm.hpp | 8 +- test/cgemm/cgemm_bf16.cpp | 12 +-- test/cgemm/cgemm_fp16.cpp | 21 +---- test/cgemm/cgemm_fp32.cpp | 19 +---- test/cgemm/cgemm_util.hpp | 82 +++++++++---------- 5 files changed, 54 insertions(+), 88 deletions(-) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp index 4f4dcafdb7..c55b86aea7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator float Run(const Argument& arg) { auto f_mk_kn_mn_real = [&](auto m, auto n) { - const int K = arg.a_m_k_real_.mDesc.GetLengths()[1]; + const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1]; if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) { @@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator float v_acc = 0; - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { float v_a_real; float v_b_real; @@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator }; auto f_mk_kn_mn_imag = [&](auto m, auto n) { - const int K = arg.a_m_k_real_.mDesc.GetLengths()[1]; + const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1]; if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) { @@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator float v_acc = 0; - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { float v_a_real; float v_b_real; diff --git a/test/cgemm/cgemm_bf16.cpp b/test/cgemm/cgemm_bf16.cpp index e97b10d0b5..3e8d7d3fa9 100644 --- a/test/cgemm/cgemm_bf16.cpp +++ b/test/cgemm/cgemm_bf16.cpp @@ -21,9 +21,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceCGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; + ck::tensor_operation::device::DeviceCGemmPtr; namespace ck { namespace tensor_operation { @@ -48,9 +48,9 @@ int main() using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; bool res = true; - std::vector gemmPtrs; + std::vector cgemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: + ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs); for(auto& cgemmPtr : cgemmPtrs) @@ -76,7 +76,7 @@ int main() RowMajor, PassThrough, PassThrough, - PassThrough>{}(gemmPtr); + PassThrough>{}(cgemmPtr); } cgemmPtrs.clear(); diff --git a/test/cgemm/cgemm_fp16.cpp b/test/cgemm/cgemm_fp16.cpp index de8c0f4c77..5818bb9694 100644 --- a/test/cgemm/cgemm_fp16.cpp +++ b/test/cgemm/cgemm_fp16.cpp @@ -18,7 +18,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceCGemmNoOpPtr = - ck::tensor_operation::device::DevicecgemmPtr; @@ -50,10 +50,7 @@ int main() bool res = true; std::vector cgemmPtrs; - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); @@ -72,10 +69,6 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); @@ -94,10 +87,6 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); @@ -116,14 +105,8 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); for(auto& cgemmPtr : cgemmPtrs) { diff --git a/test/cgemm/cgemm_fp32.cpp b/test/cgemm/cgemm_fp32.cpp index 33d3864c37..8b9e37238f 100644 --- a/test/cgemm/cgemm_fp32.cpp +++ b/test/cgemm/cgemm_fp32.cpp @@ -21,7 +21,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceCGemmNoOpPtr = - ck::tensor_operation::device::DevicecgemmPtr; @@ -54,10 +54,7 @@ int main() bool res = true; std::vector cgemmPtrs; - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); + ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); @@ -76,10 +73,6 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); @@ -98,10 +91,6 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); @@ -120,10 +109,6 @@ int main() } cgemmPtrs.clear(); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); - ck::tensor_operation::device::device_cgemm_instance:: - add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); ck::tensor_operation::device::device_cgemm_instance:: add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); diff --git a/test/cgemm/cgemm_util.hpp b/test/cgemm/cgemm_util.hpp index 93c36ca3a5..f45405b275 100644 --- a/test/cgemm/cgemm_util.hpp +++ b/test/cgemm/cgemm_util.hpp @@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { - DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); - DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); - DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); - DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); - DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - DeviceMem aux_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A_real.mDesc.GetElementSpace()); + DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A_imag.mDesc.GetElementSpace()); + DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B_real.mDesc.GetElementSpace()); + DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace()); + DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace()); + DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace()); + DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace()); - a_m_k_device_buf.ToDevice(A.mData.data()); - b_k_n_device_buf.ToDevice(B.mData.data()); + a_m_k_real_device_buf.ToDevice(A_real.mData.data()); + a_m_k_imag_device_buf.ToDevice(A_imag.mData.data()); + b_k_n_real_device_buf.ToDevice(B_real.mData.data()); + b_k_n_imag_device_buf.ToDevice(B_imag.mData.data()); auto invoker_ptr = cgemmPtr->MakeInvokerPointer(); auto argument_ptr = cgemmPtr->MakeArgumentPointer( static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), - static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), @@ -255,7 +257,7 @@ struct TestCGemm if(std::is_same::value) { res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) && - ck::utils::check_err(c_device_real.mData, c_host.mData); + ck::utils::check_err(c_device_imag.mData, c_host_imag.mData); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } else if(std::is_same::value) @@ -326,15 +328,13 @@ struct TestCGemmBF16 f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor b_k_n_imag_fp32( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_real_fp32( + Tensor c_m_n_real_host_fp32( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_host_imag_fp32( + Tensor c_m_n_imag_host_fp32( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_real_fp32( + Tensor c_m_n_real_device_fp32( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_imag_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor aux_fp32( + Tensor c_m_n_imag_device_fp32( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -361,8 +361,7 @@ struct TestCGemmBF16 c_m_n_real_host_fp32, c_m_n_imag_host_fp32, c_m_n_real_device_fp32, - c_m_n_imag_device_fp32, - aux_fp32); + c_m_n_imag_device_fp32); } auto operator()(DeviceCGemmPtr_& cgemmPtr) @@ -392,43 +391,42 @@ struct TestCGemmBF16 Tensor& c_imag_host_fp32 = std::get<12>(host_tensors); Tensor& c_real_device_fp32 = std::get<13>(host_tensors); Tensor& c_imag_device_fp32 = std::get<14>(host_tensors); - Tensor& aux_fp32 = std::get<15>(host_tensors); auto a_element_op = AElementwiseOperation{}; auto b_element_op = BElementwiseOperation{}; auto c_element_op = CElementwiseOperation{}; // use fp32 host kernel to verify bf16 device kernel - using ReferenceGemmInstance = + using ReferenceCGemmInstance = ck::tensor_operation::host::ReferenceCGemm; - ck::gemm_util::RunHostCGEMM(a_real_fp32, - a_imag_fp32, - b_real_fp32, - b_imag_fp32, - c_real_host_fp32, - c_imag_fp32, - a_element_op, - b_element_op, - c_element_op); + ck::cgemm_util::RunHostCGEMM(a_real_fp32, + a_imag_fp32, + b_real_fp32, + b_imag_fp32, + c_real_host_fp32, + c_imag_host_fp32, + a_element_op, + b_element_op, + c_element_op); // Act - ck::gemm_util::RunDeviceCGEMM(cgemmPtr, - params, - a_real_bf16, - a_imag_bf16, - b_real_bf16, - b_imag_bf16, - c_real_device_bf16, - c_imag_device_bf16, - aux_bf16, - a_element_op, - b_element_op, - c_element_op); + ck::cgemm_util::RunDeviceCGEMM(cgemmPtr, + params, + a_real_bf16, + a_imag_bf16, + b_real_bf16, + b_imag_bf16, + c_real_device_bf16, + c_imag_device_bf16, + aux_bf16, + a_element_op, + b_element_op, + c_element_op); bf16_to_f32_(c_real_device_bf16, c_real_device_fp32); bf16_to_f32_(c_imag_device_bf16, c_imag_device_fp32);