mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4335 (commit 06976b3)
=?UTF-8?q?Increase=20tolerance=20for=20FP16=20GEMM=20test?= =?UTF-8?q?s=20to=20handle=20non-deterministic=20ro=E2=80=A6=20(#4335)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …unding Three tests were failing intermittently with small errors (0.01-1.5%) due to non-deterministic FP16 accumulation order from GPU thread scheduling: - test_ck_tile_batched_gemm - test_ck_tile_grouped_gemm_preshuffle - test_ck_tile_grouped_gemm_multi_d These tests use kbatch=1 (no split-K), so errors are from order-dependent rounding, not atomics. Increased tolerances from 1e-3 to 2e-3 (0.2%) to account for FP16 precision limits while still catching real bugs. - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
d2f1541976
commit
4237aedf9a
@@ -255,7 +255,10 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_n_k, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
constexpr double rtol = 2e-3;
|
||||
constexpr double atol = 2e-3;
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -68,22 +68,20 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
|
||||
|
||||
using ComputeType = std::
|
||||
conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
return ck_tile::make_tuple(std::max({rtol, rtol_split_k, 2e-3}),
|
||||
std::max({atol, atol_split_k, 2e-3}));
|
||||
}
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
|
||||
|
||||
@@ -62,18 +62,16 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
return ck_tile::make_tuple(std::max({rtol, rtol_split_k, 2e-3}),
|
||||
std::max({atol, atol_split_k, 2e-3}));
|
||||
}
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
|
||||
|
||||
Reference in New Issue
Block a user