From ee0c84a85ed38f9e164e7631483cae17a1b1507c Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Fri, 25 Mar 2022 05:26:14 +0800 Subject: [PATCH] Gemm test return value (#148) * Add return value * Replace _Float16 to ck::half_t * A test should return 0 if success and return non-zero if fail [ROCm/composable_kernel commit: 3ba149328f2704e096b2eed7ffeacff0b54fdc8b] --- test/gemm/gemm_bf16.cpp | 1 + test/gemm/gemm_fp16.cpp | 1 + test/gemm/gemm_fp32.cpp | 1 + test/gemm/gemm_int8.cpp | 1 + test/include/test_util.hpp | 10 +++++----- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 8037ee5c08..98c96b8b58 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -113,4 +113,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp index 4ed85d170d..d7669bb242 100644 --- a/test/gemm/gemm_fp16.cpp +++ b/test/gemm/gemm_fp16.cpp @@ -151,4 +151,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 7f73296545..cd68158402 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -151,4 +151,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 99073bbd8d..bb3dbdf43b 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -129,4 +129,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index 069261f87d..07fe67ba46 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -105,11 +105,11 @@ check_err(const std::vector& out, return res; } -bool check_err(const std::vector<_Float16>& out, - const std::vector<_Float16>& ref, +bool check_err(const std::vector& out, + const std::vector& ref, const std::string& msg, - _Float16 rtol = static_cast<_Float16>(1e-3f), - _Float16 atol = static_cast<_Float16>(1e-3f)) + ck::half_t rtol = static_cast(1e-3f), + ck::half_t atol = static_cast(1e-3f)) { if(out.size() != ref.size()) { @@ -122,7 +122,7 @@ bool check_err(const std::vector<_Float16>& out, bool res{true}; int err_count = 0; double err = 0; - double max_err = std::numeric_limits<_Float16>::min(); + double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { double out_ = double(out[i]);