mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Adjust the atol and rtol and fix the check_err() using in example_hstu_attention.cpp
This commit is contained in:
@@ -191,8 +191,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
double rtol = 1.6e-2;
|
||||
double atol = 1e-5;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
@@ -200,8 +200,8 @@ auto get_elimit()
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
double rtol = 1.6e-2;
|
||||
double atol = 1e-5;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -545,7 +545,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
res = ck_tile::check_err(
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), atol, rtol);
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
|
||||
};
|
||||
|
||||
if(measure_perf)
|
||||
|
||||
Reference in New Issue
Block a user