Fix masking for the min_full_attn_seqlen > 0 case

This commit is contained in:
Qianfeng Zhang
2025-06-22 16:23:57 +00:00
parent c87a217475
commit 63a47d7ec5

View File

@@ -149,19 +149,31 @@ struct HstuBlockMaskWithLocal
// the diagonal line is always considered
if constexpr(kUseCausal)
{
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
return true;
if(min_full_attn_seqlen > 0)
{
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id > col_id) || (row == col)) &&
(row_id - col_id <= max_attn_len));
};
}
else
{
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
return true;
if(min_full_attn_seqlen > 0)
{
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
};
}
return false;
@@ -175,19 +187,31 @@ struct HstuBlockMaskWithLocal
// the diagonal line is always considered
if constexpr(kUseCausal)
{
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
return true;
if(min_full_attn_seqlen > 0)
{
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id > col_id) || (row == col)) &&
(row_id - col_id <= max_attn_len));
};
}
else
{
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
return true;
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
return true;
if(min_full_attn_seqlen > 0)
{
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
};
}
return false;
@@ -213,20 +237,39 @@ struct HstuBlockMaskWithLocal
// the diagonal line is always considered
if constexpr(kUseCausal)
{
bool res1 =
(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col));
bool res2 =
((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen));
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res1 || res2);
return static_cast<int>(res);
}
else
{
bool res =
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
return static_cast<int>(res);
};
}
else
{
bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) ||
(row == col));
bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen));
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res1 || res2);
return static_cast<int>(res);
}
else
{
bool res = (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
return static_cast<int>(res);
};
}
}
else
@@ -240,20 +283,39 @@ struct HstuBlockMaskWithLocal
// the diagonal line is always considered
if constexpr(kUseCausal)
{
bool res1 =
(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col));
bool res2 =
((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen));
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res1 || res2);
return static_cast<int>(res);
}
else
{
bool res =
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
return static_cast<int>(res);
};
}
else
{
bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) ||
(row == col));
bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen));
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res1 || res2);
return static_cast<int>(res);
}
else
{
bool res = (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
return static_cast<int>(res);
};
}
}
};