add notes

This commit is contained in:
amd-ruitang3
2025-06-16 11:16:05 +00:00
parent f8190b37a5
commit aba6e663dc

View File

@@ -243,6 +243,20 @@ namespace impl {
// this version only has 2 variations: masking and non-masking
// This is more friendly to codegen (e.g. fewer kernels need to be generated)
// ... with the trade-off that it may use more instructions in causal mode
// clang-format off
/* y_ratio describes the step length of y-direction changes: the mask
boundary advances by one x position every y_ratio rows. It is used in
certain performance optimization scenarios like merging seqlen
and qk_head_ratio, for example:
x=1/y=6/y_ratio=2(top-left)
1 * * * * * * *
1 * * * * * * *
1 1 * * * * * *
1 1 * * * * * *
1 1 1 * * * * *
1 1 1 * * * * *
*/
// clang-format on
template <bool IsMasking_ = true>
struct SimplifiedGenericAttentionMask
{
@@ -254,9 +268,11 @@ struct SimplifiedGenericAttentionMask
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
{
}
// TODO: Y or i_y must not be negative if y_ratio is not equal to 1,
// because integer_divide_floor does not support negative numbers.
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_=1)
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_ = 1)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_), y_ratio(y_ratio_)
{
}
@@ -284,14 +300,14 @@ struct SimplifiedGenericAttentionMask
{
// get the tile start/end range, assuming we loop along X tile by tile
index_t x_start = [&]() {
index_t tmp = max((-y + i_y + y_ratio) / y_ratio, 0);
index_t tmp = max(integer_divide_floor(-y + i_y + y_ratio, y_ratio), 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative; we skip the clamp here and let the caller check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min((i_y + YTile - 1) / y_ratio + x, x_total);
index_t tmp = min(integer_divide_floor(i_y + YTile - 1, y_ratio) + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
@@ -357,8 +373,10 @@ struct SimplifiedGenericAttentionMask
}
else
{
index_t x_start = (-y + i_y + y_ratio) / y_ratio; // this could be negative, but it's fine
index_t x_end = min(i_y / y_ratio + x, x_total); // need min in case x is padded
index_t x_start = integer_divide_floor(
-y + i_y + y_ratio, y_ratio); // this could be negative, but it's fine
index_t x_end = min(integer_divide_floor(i_y, y_ratio) + x,
x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
@@ -388,8 +406,10 @@ struct SimplifiedGenericAttentionMask
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y / y_ratio + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
bool top_right_edge = i_x_end > min(integer_divide_floor(i_y, y_ratio) + x,
x_total); // consider right pad
bool bottom_left_edge =
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge;