mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add notes
This commit is contained in:
@@ -243,6 +243,20 @@ namespace impl {
|
||||
// this version only have 2 variation: masking and non-masking
|
||||
// This is more friendly to codegen (e.g. need generate less kernel)
|
||||
// ... with the trade-off that may have more instruction in causal mode
|
||||
|
||||
// clang-format off
|
||||
/* y_ratio is used to describe the step length of y-direction changes
|
||||
in certain performance optimization scenarios like merging seqlen
|
||||
and qk_head_ratio, for example:
|
||||
x=1/y=6/y_ratio=2(top-left)
|
||||
1 * * * * * * *
|
||||
1 * * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 1 * * * * *
|
||||
1 1 1 * * * * *
|
||||
*/
|
||||
// clang-format on
|
||||
template <bool IsMasking_ = true>
|
||||
struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
@@ -254,9 +268,11 @@ struct SimplifiedGenericAttentionMask
|
||||
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
// TODO: Y or i_y cannot be negative if y_ratio is not equal to 1,
|
||||
// because integer_divide_floor do not support negative numbers.
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_=1)
|
||||
SimplifiedGenericAttentionMask(
|
||||
index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_ = 1)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_), y_ratio(y_ratio_)
|
||||
{
|
||||
}
|
||||
@@ -284,14 +300,14 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = max((-y + i_y + y_ratio) / y_ratio, 0);
|
||||
index_t tmp = max(integer_divide_floor(-y + i_y + y_ratio, y_ratio), 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min((i_y + YTile - 1) / y_ratio + x, x_total);
|
||||
index_t tmp = min(integer_divide_floor(i_y + YTile - 1, y_ratio) + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
@@ -357,8 +373,10 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_start = (-y + i_y + y_ratio) / y_ratio; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y / y_ratio + x, x_total); // need min in case x is padded
|
||||
index_t x_start = integer_divide_floor(
|
||||
-y + i_y + y_ratio, y_ratio); // this could be negative, but it's fine
|
||||
index_t x_end = min(integer_divide_floor(i_y, y_ratio) + x,
|
||||
x_total); // need min in case x is padded
|
||||
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
@@ -388,8 +406,10 @@ struct SimplifiedGenericAttentionMask
|
||||
index_t i_y_end = i_y + TileHeight;
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
bool top_right_edge = i_x_end > min(i_y / y_ratio + x, x_total); // consider right pad
|
||||
bool bottom_left_edge = i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
|
||||
bool top_right_edge = i_x_end > min(integer_divide_floor(i_y, y_ratio) + x,
|
||||
x_total); // consider right pad
|
||||
bool bottom_left_edge =
|
||||
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
|
||||
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge;
|
||||
|
||||
Reference in New Issue
Block a user