Adjust the threshold values for fp16/bf16 in the example

This commit is contained in:
Qianfeng Zhang
2025-05-20 07:48:54 +00:00
parent 29cf1610f1
commit 0a8ea6bd02

View File

@@ -163,8 +163,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
@@ -172,8 +172,8 @@ auto get_elimit()
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 2e-2;
double atol = 2e-2;
double rtol = 2e-3;
double atol = 2e-3;
return ck_tile::make_tuple(rtol, atol);
}
@@ -475,7 +475,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
if(save_mask)
dumpBufferToFile("ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
dumpBufferToFile(
"ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
auto [rtol, atol] = get_elimit<InOutDataType>();