mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Adjust the threshold values for fp16/bf16 in the example
This commit is contained in:
@@ -163,8 +163,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
@@ -172,8 +172,8 @@ auto get_elimit()
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 2e-2;
|
||||
double atol = 2e-2;
|
||||
double rtol = 2e-3;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -475,7 +475,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
|
||||
if(save_mask)
|
||||
dumpBufferToFile("ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
|
||||
dumpBufferToFile(
|
||||
"ck_hstu_mask.dat", mask_host.data(), mask_host.get_element_space_size());
|
||||
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user