Revert "Add attn sink (#2892)" (#3250)

This reverts commit bbe1d3a917ee92655224c0f1528ace3a7b0e82a8.

[ROCm/composable_kernel commit: 5adaa201ed]
This commit is contained in:
asleepzzz
2025-11-20 23:55:15 +08:00
committed by GitHub
parent f4ba63deb7
commit 06d2e609cd
25 changed files with 195 additions and 940 deletions

View File

@@ -86,22 +86,21 @@ struct GenericAttentionMask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
: GenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
@@ -142,44 +141,6 @@ struct GenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -234,30 +195,6 @@ struct GenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1;
index_t x_end = min(i_y + x, x_total);
if constexpr(IsLocal)
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end;
}
else
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -300,7 +237,7 @@ struct GenericAttentionMask
}
private:
index_t y, x, sink;
index_t y, x;
index_t y_total, x_total;
};
@@ -323,23 +260,21 @@ struct SimplifiedGenericAttentionMask
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
@@ -373,38 +308,6 @@ struct SimplifiedGenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
@@ -422,29 +325,6 @@ struct SimplifiedGenericAttentionMask
ck_tile::min(origin_end, split_end));
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; // 128
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
const index_t start = ck_tile::max(origin_start, split_start);
const index_t end = ck_tile::min(origin_end, split_end);
const bool is_first_intersecting_split =
(split_start <= origin_start && split_end >= origin_start);
const bool sink_in_range = (sink_seq_end <= start);
const index_t sink_offset =
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
return ck_tile::make_tuple(sink_offset, start, end);
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -488,22 +368,11 @@ struct SimplifiedGenericAttentionMask
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -537,7 +406,7 @@ struct SimplifiedGenericAttentionMask
}
private:
index_t y, x, sink;
index_t y, x;
index_t y_total, x_total;
};
@@ -738,7 +607,6 @@ struct SimplifiedRatioAttentionMask
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
@@ -756,21 +624,7 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, sink_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
return ck_tile::make_tuple(y, x, y_total, x_total);
}
template <typename MaskType>
@@ -782,7 +636,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, 0, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
}
} // namespace ck_tile

View File

@@ -162,17 +162,6 @@ struct StandardAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
template <bool UseExp2 = false>
@@ -235,17 +224,6 @@ struct LogitsSoftCap
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
constexpr uint32_t CUSTOM_MASK = 1U;
@@ -319,17 +297,6 @@ struct ComposedAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
} // namespace ck_tile

View File

@@ -198,7 +198,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -362,7 +362,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -426,7 +425,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -511,7 +509,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -573,7 +570,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -1030,7 +1026,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -56,7 +56,6 @@ struct FmhaFwdKernel
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -113,7 +112,7 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload") + (kHasSink ? "_sink" : "_nsink");
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
@@ -201,7 +200,7 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -375,7 +374,6 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -434,7 +432,6 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -521,7 +518,6 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -569,7 +565,6 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -620,7 +615,6 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -668,7 +662,6 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -713,7 +706,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -773,7 +765,6 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -857,7 +848,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -901,7 +891,6 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -948,7 +937,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -992,7 +980,6 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -1484,7 +1471,6 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
@@ -2214,7 +2200,6 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -55,7 +55,6 @@ struct FmhaFwdPagedKVKernel
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -102,7 +101,7 @@ struct FmhaFwdPagedKVKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -190,7 +189,7 @@ struct FmhaFwdPagedKVKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -327,7 +326,6 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -381,7 +379,6 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -456,7 +453,6 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
return MakeKargsImpl(q_ptr,
@@ -499,7 +495,6 @@ struct FmhaFwdPagedKVKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type);
}
@@ -541,7 +536,6 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -596,7 +590,6 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -667,7 +660,6 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -707,7 +699,6 @@ struct FmhaFwdPagedKVKernel
batch_stride_v,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q);
}
@@ -1266,7 +1257,6 @@ struct FmhaFwdPagedKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -51,7 +51,6 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
@@ -102,7 +101,7 @@ struct FmhaFwdSplitKVKernel
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -199,7 +198,7 @@ struct FmhaFwdSplitKVKernel
struct MaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -326,7 +325,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -386,7 +384,6 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -454,7 +451,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -512,7 +508,6 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -999,7 +994,6 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -57,7 +57,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -229,22 +228,10 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -268,6 +255,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return o_acc;
}
}
// k_dram_block_window
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
@@ -286,36 +274,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), bias_n_offset},
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
// v_dram_window
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, kv_load_start}, // TODO: hdim split?
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
@@ -342,16 +321,9 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
i_page_block_k, k_dram_block_window, {kN0, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -470,7 +442,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, k_move_offset});
move_tile_window(bias_dram_window, {0, kN0});
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
@@ -502,29 +474,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto row, auto col) {
return mask.IsOutOfSinkBound(row, col);
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
else
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
}
@@ -690,12 +647,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
// tail
{
block_sync_lds();

View File

@@ -57,7 +57,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -257,23 +256,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
@@ -317,33 +304,24 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {kv_load_start, 0});
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), bias_n_offset},
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, kv_load_start}, // TODO: hdim split?
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
@@ -391,13 +369,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
@@ -522,7 +494,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, k_move_offset});
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -533,7 +505,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -558,26 +530,12 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
@@ -588,7 +546,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
@@ -784,8 +742,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);

View File

@@ -56,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -230,23 +229,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
@@ -289,35 +274,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {kv_load_start, 0});
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), bias_n_offset},
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, kv_load_start}, // TODO: hdim split?
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -346,18 +320,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
i_page_block_k, k_dram_block_window, {kN0, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -476,7 +441,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, k_move_offset});
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -487,7 +452,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -512,26 +477,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
@@ -696,12 +647,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
// tail
{
block_sync_lds();

View File

@@ -62,7 +62,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -115,7 +114,6 @@ struct BlockFmhaFwdPagedKVPipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -169,7 +167,6 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
// extract tile size attributes to remove dependency on traits

View File

@@ -57,7 +57,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
@@ -234,26 +233,10 @@ struct BlockFmhaPipelineQRKSVS
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -279,22 +262,22 @@ struct BlockFmhaPipelineQRKSVS
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
{seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, kv_load_start);
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start}, // TODO: hdim split?
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -467,11 +450,6 @@ struct BlockFmhaPipelineQRKSVS
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -482,34 +460,17 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -619,23 +580,11 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(kHasDropout)
{
// K and dropout use the same address in LDS, finish loading from k_lds_window by
// gemm_0 to reuse LDS.
block_sync_lds();
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr, seq_offset, p_compute, randval_dram_window);
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
}
block_sync_lds();
@@ -687,14 +636,6 @@ struct BlockFmhaPipelineQRKSVS
});
}
// move K tile windows
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{

View File

@@ -62,7 +62,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -278,26 +277,11 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -325,7 +309,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
@@ -348,16 +332,16 @@ struct BlockFmhaPipelineQRKSVSAsync
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, kv_load_start);
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start}, // TODO: hdim split?
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
@@ -494,11 +478,6 @@ struct BlockFmhaPipelineQRKSVSAsync
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -510,34 +489,17 @@ struct BlockFmhaPipelineQRKSVSAsync
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -685,21 +647,11 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr, seq_offset, p_compute, randval_dram_window);
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() {
@@ -765,16 +717,8 @@ struct BlockFmhaPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&

View File

@@ -69,7 +69,6 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasUnevenSplits = true;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||

View File

@@ -19,9 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -36,7 +35,6 @@ struct TileFmhaTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
@@ -66,9 +64,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kIsPagedKV_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileFmhaFwdPagedKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -83,7 +80,6 @@ struct TileFmhaFwdPagedKVTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
@@ -98,8 +94,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kIsPagedKV_,
bool kHasUnevenSplits_,
bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kHasSink_ = false>
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -116,7 +111,6 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,