Use type_convert to convert float constant to CompDataType

This commit is contained in:
Qianfeng Zhang
2026-04-24 15:46:26 +00:00
parent 1f2e2a272e
commit b9d4be0982
2 changed files with 2 additions and 2 deletions

View File

@@ -472,7 +472,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
// use the m_old[i] as the m-for-stablization if m_old[i] - m[i] >= -8.0f
// and still keep the m-for-stablization in m[]
if(m_old[i_idx] - m[i_idx] >= -8.0f)
if(m_old[i_idx] - m[i_idx] >= type_convert<CompDataType>(-8.0f))
m(i_idx) = m_old[i_idx];
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {

View File

@@ -478,7 +478,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
// use the m_old[i] as the m-for-stablization if m_old[i] - m[i] >= -8.0f
// and still keep the m-for-stablization in m[]
if(m_old[i_idx] - m[i_idx] >= -8.0f)
if(m_old[i_idx] - m[i_idx] >= type_convert<CompDataType>(-8.0f))
m(i_idx) = m_old[i_idx];
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {