From 0be8a1ec2aa306fdaa751303c919fc96d674e590 Mon Sep 17 00:00:00 2001 From: shay-li77 Date: Wed, 9 Jul 2025 01:31:53 +0800 Subject: [PATCH] add SimplifiedRatioAttentionMask --- .../ck_tile/ops/fmha/block/block_masking.hpp | 363 +++++++++++------- 1 file changed, 224 insertions(+), 139 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 51a5952d2f..f5c12e11d2 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -234,42 +234,21 @@ struct GenericAttentionMask // clang-format off namespace impl { - template struct SimplifiedMaskName; - template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; - template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; - template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask_ratio"; }; - template<> struct SimplifiedMaskName { static constexpr const char * name = "mask_ratio"; }; + template struct SimplifiedMaskName; + template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; } // clang-format on // this version only have 2 variation: masking and non-masking // This is more friendly to codegen (e.g. need generate less kernel) // ... with the trade-off that may have more instruction in causal mode - -// clang-format off -/* y_ratio is used to describe the step length of y-direction changes - in certain performance optimization scenarios like merging seqlen - and qk_head_ratio, for example: - - x=1/y=6/y_ratio=2(top-left) - 1 * * * * * * * - 1 * * * * * * * - 1 1 * * * * * * - 1 1 * * * * * * - 1 1 1 * * * * * - 1 1 1 * * * * * - - set EnableRatio_=true to enable this feature -*/ -// clang-format on -template +template struct SimplifiedGenericAttentionMask { static constexpr bool IsMasking = IsMasking_; // false will disable masking - static constexpr bool EnableRatio = EnableRatio_; // false will disable y-ratio - - static constexpr const char* name = impl::SimplifiedMaskName::name; + static constexpr const char* name = impl::SimplifiedMaskName::name; CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) @@ -281,18 +260,6 @@ struct SimplifiedGenericAttentionMask : y(y_), x(x_), y_total(y_total_), x_total(x_total_) { } - CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask( - index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) - : y(y_real_ * static_cast(y_ratio_mdiv_.get())), - x(x_), - y_total(y_total_), - x_total(x_total_), - y_real(y_real_), - y_ratio(static_cast(y_ratio_mdiv_.get())), - y_ratio_mdiv(y_ratio_mdiv_) - { - } template CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), @@ -315,45 +282,20 @@ struct SimplifiedGenericAttentionMask } else { - if constexpr(!EnableRatio) - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - index_t tmp = max(-y + i_y + 1, 0); - return (tmp / XTile) * XTile; // round to tile aligned - }(); + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - index_t tmp = min(i_y + YTile - 1 + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); - return ck_tile::make_tuple(x_start, x_end); - } - else - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - index_t tmp = - -y_real + - static_cast(y_ratio_mdiv.div(static_cast(i_y))) + 1; - - return (tmp / XTile) * XTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - uint32_t y_offset = i_y + YTile - 1; - index_t tmp = - min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); - - return ck_tile::make_tuple(x_start, x_end); - } + return ck_tile::make_tuple(x_start, x_end); } } @@ -387,40 +329,20 @@ struct SimplifiedGenericAttentionMask } else { - if constexpr(!EnableRatio) - { - // get the tile start/end range assum we loop over along Y tile by tile - index_t y_start = [&]() { - index_t tmp = max(-x + i_x + 1, 0); - return (tmp / YTile) * YTile; // round to tile aligned - }(); + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t y_end = [&]() { - index_t tmp = min(i_x + XTile - 1 + y, y_total); - return ((tmp + YTile - 1) / YTile) * YTile; - }(); + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); - return ck_tile::make_tuple(y_start, y_end); - } - else - { - // get the tile start/end range assum we loop over along Y tile by tile - index_t y_start = [&]() { - index_t tmp = max((-x + i_x + 1) * y_ratio, 0); - return (tmp / YTile) * YTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t y_end = [&]() { - index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); - return ((tmp + YTile - 1) / YTile) * YTile; - }(); - - return ck_tile::make_tuple(y_start, y_end); - } + return ck_tile::make_tuple(y_start, y_end); } } @@ -435,20 +357,10 @@ struct SimplifiedGenericAttentionMask } else { - if constexpr(!EnableRatio) - { - index_t x_start = -y + i_y + 1; // this could be negative, but it's fine - index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; - } - else - { - index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); - index_t x_start = -y_real + x_tmp + 1; - index_t x_end = min(x_tmp + x, - x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; - } + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = min(i_y + x, x_total); // need min in case x is padded + + return i_x < x_start || i_x >= x_end || i_y >= y_total; } } @@ -476,24 +388,197 @@ struct SimplifiedGenericAttentionMask index_t i_y_end = i_y + TileHeight; // index_t x_end = min(i_y + x, x_total); - if constexpr(!EnableRatio) - { - bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad - // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for - // now - return top_right_edge || bottom_left_edge; - } - else - { - uint32_t y_tmp = static_cast(i_y); - bool top_right_edge = - i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, - x_total); // consider right pad - bool bottom_left_edge = - i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad - return top_right_edge || bottom_left_edge; - } + bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad + // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// clang-format off +namespace impl { + template struct SimplifiedRatioMaskName; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version is used for cases that the step length of y-direction changes greater than one. It +// means that the mask is not a regular triangular matrix. + +// clang-format off +/* y_ratio is used to describe the step length of y-direction changes + in certain performance optimization scenarios like merging seqlen + and qk_head_ratio, for example: + + x=1/y=6/y_ratio=2(top-left) + 1 * * * * * * * + 1 * * * * * * * + 1 1 * * * * * * + 1 1 * * * * * * + 1 1 1 * * * * * + 1 1 1 * * * * * + +*/ +// clang-format on +template +struct SimplifiedRatioAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedRatioMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{}) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask( + index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) + : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast(y_ratio_mdiv_.get()), + /*x_=*/x_, + /*y_total_=*/y_total_, + /*x_total_=*/x_total_, + /*y_real_=*/y_real_, + /*y_ratio_=*/static_cast(y_ratio_mdiv_.get()), + /*y_ratio_mdiv_=*/y_ratio_mdiv_) + + { + } + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask(index_t y_, + index_t x_, + index_t y_total_, + index_t x_total_, + index_t y_real_, + index_t y_ratio_, + mdiv y_ratio_mdiv_) + : y(y_), + x(x_), + y_total(y_total_), + x_total(x_total_), + y_real(y_real_), + y_ratio(y_ratio_), + y_ratio_mdiv(y_ratio_mdiv_) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = -y_real + + static_cast(y_ratio_mdiv.div(static_cast(i_y))) + + 1; + + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + uint32_t y_offset = i_y + YTile - 1; + index_t tmp = min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max((-x + i_x + 1) * y_ratio, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); + index_t x_start = -y_real + x_tmp + 1; + index_t x_end = min(x_tmp + x, + x_total); // need min in case x is padded + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + uint32_t y_tmp = static_cast(i_y); + bool top_right_edge = i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, + x_total); // consider right pad + bool bottom_left_edge = + i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad + return top_right_edge || bottom_left_edge; } }