Add attn sink (#2892)

* enable attn sink

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update attn_sink script

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix some error

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* clang-format

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd_kernel'mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update block_fmha_pipeline_qr_ks_vs.hpp

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix ci error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix format error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_bwd_pipeline_default_policy.hpp

* Update fmha_fwd_runner.hpp

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* update splitkv_pipline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update splitkv&pagedkv pipeline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* add sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update attn_sink result log

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update smoke_test_fwd_sink.sh

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test file

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test script

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

* use constexpr kHasSink for sink in fmha pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update by pre-commit

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove causal mask setting logic from mask.hpp

Removed the mask setting logic for causal masks.

* fix ci error that some usage of lamada not support in c++17

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update remod.py

* add smoke sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update fmha_pagedkv_prefill.py

* Update FmhaFwdPipeline parameters in fmha_fwd.py

* update block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix c++17 unsupprot error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

* Fix formatting of sink_seq_end assignment

* Fix indentation for sink_seq_end assignment

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Signed-off-by: LJ-underdog <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Linjun-AMD
2025-11-20 19:24:05 +08:00
committed by GitHub
parent 84540edff3
commit 9fa4e8d5ab
25 changed files with 940 additions and 195 deletions

View File

@@ -86,21 +86,22 @@ struct GenericAttentionMask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
: GenericAttentionMask(0, 0, y_total_, x_total_)
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -141,6 +142,44 @@ struct GenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -195,6 +234,30 @@ struct GenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1;
index_t x_end = min(i_y + x, x_total);
if constexpr(IsLocal)
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end;
}
else
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -237,7 +300,7 @@ struct GenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
@@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask
ck_tile::min(origin_end, split_end));
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; // 128
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
const index_t start = ck_tile::max(origin_start, split_start);
const index_t end = ck_tile::min(origin_end, split_end);
const bool is_first_intersecting_split =
(split_start <= origin_start && split_end >= origin_start);
const bool sink_in_range = (sink_seq_end <= start);
const index_t sink_offset =
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
return ck_tile::make_tuple(sink_offset, start, end);
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -607,6 +738,7 @@ struct SimplifiedRatioAttentionMask
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
@@ -624,7 +756,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, sink_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
}
template <typename MaskType>
@@ -636,7 +782,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
left_size, right_size, 0, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
}
} // namespace ck_tile