mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Fix masking for min_full_attn_seqlen > 0 situation
This commit is contained in:
@@ -149,19 +149,31 @@ struct HstuBlockMaskWithLocal
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
(row_id - col_id <= max_attn_len));
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
};
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -175,19 +187,31 @@ struct HstuBlockMaskWithLocal
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
(row_id - col_id <= max_attn_len));
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col))
|
||||
return true;
|
||||
|
||||
if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen))
|
||||
return true;
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
};
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -213,20 +237,39 @@ struct HstuBlockMaskWithLocal
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool res1 =
|
||||
(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col));
|
||||
bool res2 =
|
||||
((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen));
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res1 || res2);
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res =
|
||||
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) ||
|
||||
(row == col));
|
||||
bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen));
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res1 || res2);
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -240,20 +283,39 @@ struct HstuBlockMaskWithLocal
|
||||
// diagonal line are always considerred
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool res1 =
|
||||
(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col));
|
||||
bool res2 =
|
||||
((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen));
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res1 || res2);
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res =
|
||||
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) ||
|
||||
(row == col));
|
||||
bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen));
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res1 || res2);
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user