diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 37005cccc1..5e0a930ed3 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -255,7 +255,10 @@ class TestCkTileBatchedGemm : public ::testing::Test ck_tile::reference_batched_gemm( a_m_k, b_n_k, c_m_n_host_ref); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + constexpr double rtol = 2e-3; + constexpr double atol = 2e-3; + pass = ck_tile::check_err( + c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol); EXPECT_TRUE(pass); } }; diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index c6e311a65c..0fb350b32d 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -68,22 +68,20 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test using ComputeType = std:: conditional_t; - // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + return ck_tile::make_tuple(std::max({rtol, rtol_split_k, 2e-3}), + std::max({atol, atol_split_k, 2e-3})); } using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index e588ad2cc1..f6da94829d 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -62,18 +62,16 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test { using ComputeType = std::conditional_t; - // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + return ck_tile::make_tuple(std::max({rtol, rtol_split_k, 2e-3}), + std::max({atol, atol_split_k, 2e-3})); } using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;