From 0a1eb8381d2bd7f5c1a69760b702a775aa2897ca Mon Sep 17 00:00:00 2001 From: shay-li77 Date: Wed, 9 Jul 2025 23:18:55 +0800 Subject: [PATCH] support y-direction step length greater than 1 for SimplifiedGenericAttentionMask (#2338) * mask support ratio for y axis * format code * add notes for param y_ratio * fix comments error * support template and mdiv for ratio mask * refactor y-ratio mask constructor * optimize coordinate calculation * add SimplifiedRatioAttentionMask [ROCm/composable_kernel commit: d814fefe1898971b2c3eb97b986bdafc450f18b5] --- .../ck_tile/ops/fmha/block/block_masking.hpp | 190 ++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 726543b97a..f5c12e11d2 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -401,6 +401,196 @@ struct SimplifiedGenericAttentionMask index_t y_total, x_total; }; +// clang-format off +namespace impl { + template struct SimplifiedRatioMaskName; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version is used for cases that the step length of y-direction changes greater than one. It +// means that the mask is not a regular triangular matrix. + +// clang-format off +/* y_ratio is used to describe the step length of y-direction changes + in certain performance optimization scenarios like merging seqlen + and qk_head_ratio, for example: + + x=1/y=6/y_ratio=2(top-left) + 1 * * * * * * * + 1 * * * * * * * + 1 1 * * * * * * + 1 1 * * * * * * + 1 1 1 * * * * * + 1 1 1 * * * * * + +*/ +// clang-format on +template +struct SimplifiedRatioAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedRatioMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{}) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask( + index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) + : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast(y_ratio_mdiv_.get()), + /*x_=*/x_, + /*y_total_=*/y_total_, + /*x_total_=*/x_total_, + /*y_real_=*/y_real_, + /*y_ratio_=*/static_cast(y_ratio_mdiv_.get()), + /*y_ratio_mdiv_=*/y_ratio_mdiv_) + + { + } + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask(index_t y_, + index_t x_, + index_t y_total_, + index_t x_total_, + index_t y_real_, + index_t y_ratio_, + mdiv y_ratio_mdiv_) + : y(y_), + x(x_), + y_total(y_total_), + x_total(x_total_), + y_real(y_real_), + y_ratio(y_ratio_), + y_ratio_mdiv(y_ratio_mdiv_) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = -y_real + + static_cast(y_ratio_mdiv.div(static_cast(i_y))) + + 1; + + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + uint32_t y_offset = i_y + YTile - 1; + index_t tmp = min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max((-x + i_x + 1) * y_ratio, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); + index_t x_start = -y_real + x_tmp + 1; + index_t x_end = min(x_tmp + x, + x_total); // need min in case x is padded + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + uint32_t y_tmp = static_cast(i_y); + bool top_right_edge = i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, + x_total); // consider right pad + bool bottom_left_edge = + i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; + // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y + index_t y_real; + index_t y_ratio; + mdiv y_ratio_mdiv; +}; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask