mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Stop using cast_tile_pk_fp16_fp32 to improve accuracy of fp16 HSTU attention
This commit is contained in:
@@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 2e-2;
|
||||
double atol = 2e-2;
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -446,14 +446,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window);
|
||||
}
|
||||
|
||||
auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
}();
|
||||
auto p = cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1]));
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user