diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 5721aa34fa..33e4e12ac5 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -243,6 +243,20 @@ namespace impl { // this version only have 2 variation: masking and non-masking // This is more friendly to codegen (e.g. need generate less kernel) // ... with the trade-off that may have more instruction in causal mode + +// clang-format off +/* y_ratio is used to describe the step length of y-direction changes + in certain performance optimization scenarios like merging seqlen + and qk_head_ratio, for example: +x=1/y=6/y_ratio=2(top-left) +1 * * * * * * * +1 * * * * * * * +1 1 * * * * * * +1 1 * * * * * * +1 1 1 * * * * * +1 1 1 * * * * * +*/ +// clang-format on template struct SimplifiedGenericAttentionMask { @@ -254,9 +268,11 @@ struct SimplifiedGenericAttentionMask : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) { } - + // TODO: Y or i_y cannot be negative if y_ratio is not equal to 1, + // because integer_divide_floor do not support negative numbers. CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_=1) + SimplifiedGenericAttentionMask( + index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_ = 1) : y(y_), x(x_), y_total(y_total_), x_total(x_total_), y_ratio(y_ratio_) { } @@ -284,14 +300,14 @@ struct SimplifiedGenericAttentionMask { // get the tile start/end range assum we loop over along X tile by tile index_t x_start = [&]() { - index_t tmp = max((-y + i_y + y_ratio) / y_ratio, 0); + index_t tmp = max(integer_divide_floor(-y + i_y + y_ratio, y_ratio), 0); return (tmp / XTile) * XTile; // round to tile aligned }(); // TODO: end could be negative, we ignore clamp here, and let caller to check // ... in which case end-start is negative index_t x_end = [&]() { - index_t tmp = min((i_y + YTile - 1) / y_ratio + x, x_total); + index_t tmp = min(integer_divide_floor(i_y + YTile - 1, y_ratio) + x, x_total); return ((tmp + XTile - 1) / XTile) * XTile; }(); @@ -357,8 +373,10 @@ struct SimplifiedGenericAttentionMask } else { - index_t x_start = (-y + i_y + y_ratio) / y_ratio; // this could be negative, but it's fine - index_t x_end = min(i_y / y_ratio + x, x_total); // need min in case x is padded + index_t x_start = integer_divide_floor( + -y + i_y + y_ratio, y_ratio); // this could be negative, but it's fine + index_t x_end = min(integer_divide_floor(i_y, y_ratio) + x, + x_total); // need min in case x is padded return i_x < x_start || i_x >= x_end || i_y >= y_total; } @@ -388,8 +406,10 @@ struct SimplifiedGenericAttentionMask index_t i_y_end = i_y + TileHeight; // index_t x_end = min(i_y + x, x_total); - bool top_right_edge = i_x_end > min(i_y / y_ratio + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad + bool top_right_edge = i_x_end > min(integer_divide_floor(i_y, y_ratio) + x, + x_total); // consider right pad + bool bottom_left_edge = + i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now return top_right_edge || bottom_left_edge;