Remove using cast_tile_pk_fp16_fp32 for better accuracy for fp16 hstu attention

This commit is contained in:
Qianfeng Zhang
2025-05-06 08:24:03 +00:00
parent 611f2ce1f9
commit 374e0626e6
2 changed files with 4 additions and 10 deletions

View File

@@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
template <typename DataType>
auto get_elimit()
{
double rtol = 2e-2;
double atol = 2e-2;
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}

View File

@@ -446,14 +446,8 @@ struct HstuAttentionFwdPipelineQRKSVS
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
}
auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
}();
auto p = cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
block_sync_lds();