diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 8037ee5c08..98c96b8b58 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -113,4 +113,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp index 4ed85d170d..d7669bb242 100644 --- a/test/gemm/gemm_fp16.cpp +++ b/test/gemm/gemm_fp16.cpp @@ -151,4 +151,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 7f73296545..cd68158402 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -151,4 +151,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 99073bbd8d..bb3dbdf43b 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -129,4 +129,5 @@ int main() } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index 069261f87d..07fe67ba46 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -105,11 +105,11 @@ check_err(const std::vector& out, return res; } -bool check_err(const std::vector<_Float16>& out, - const std::vector<_Float16>& ref, +bool check_err(const std::vector& out, + const std::vector& ref, const std::string& msg, - _Float16 rtol = static_cast<_Float16>(1e-3f), - _Float16 atol = static_cast<_Float16>(1e-3f)) + ck::half_t rtol = static_cast(1e-3f), + ck::half_t atol = static_cast(1e-3f)) { if(out.size() != ref.size()) { @@ -122,7 +122,7 @@ bool check_err(const std::vector<_Float16>& out, bool res{true}; int err_count = 0; double err = 0; - double max_err = std::numeric_limits<_Float16>::min(); + double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { double out_ = double(out[i]);