mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add attention sink support for FMHA FWD (#3368)
* Revert "Revert "Add attn sink (#2892)" (#3250)"
This reverts commit 5adaa201ed.
* fix conflict
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Add F_sink parameter to FmhaFwdPipeline
* Update tile_fmha_traits.hpp
* Refactor pipeline creation in fmha_fwd.py
Updated the pipeline creation logic to include 'sink' parameter in product combinations and adjusted the FmhaFwdPipeline calls accordingly.
* Update fmha_fwd.py
* Update fmha_fwd.py
* Update example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* update CHANGELOG.md
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update CHANGELOG with new features and support
* Update fmha_fwd.hpp
* Update CHANGELOG.md
* Update smoke_test_fwd_sink.sh
* Update correct_test_fwd_sink.sh
* Update smoke_test_fwd_sink.sh
---------
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -86,21 +86,22 @@ struct GenericAttentionMask
|
||||
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: GenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
sink(mask_coord.at(number<2>{})),
|
||||
y_total(mask_coord.at(number<3>{})),
|
||||
x_total(mask_coord.at(number<4>{}))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -141,6 +142,44 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, 0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
||||
if(x_start <= sink_seq_end && sink > 0)
|
||||
return ck_tile::make_tuple(0, 0, x_end);
|
||||
else
|
||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
@@ -195,6 +234,30 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
return i_x >= x_total;
|
||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
||||
index_t x_start = -y + i_y + 1;
|
||||
index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x < x_start || i_x >= x_end;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
@@ -237,7 +300,7 @@ struct GenericAttentionMask
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
@@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask
|
||||
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
SimplifiedGenericAttentionMask(
|
||||
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
sink(mask_coord.at(number<2>{})),
|
||||
y_total(mask_coord.at(number<3>{})),
|
||||
x_total(mask_coord.at(number<4>{}))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, 0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
||||
|
||||
if(x_start <= sink_seq_end && sink > 0)
|
||||
return ck_tile::make_tuple(0, 0, x_end);
|
||||
else
|
||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
|
||||
number<TileHeight> height,
|
||||
@@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask
|
||||
ck_tile::min(origin_end, split_end));
|
||||
}
|
||||
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
|
||||
number<TileHeight> height,
|
||||
number<TileWidth> width,
|
||||
index_t num_splits,
|
||||
index_t i_split) const
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split; // 128
|
||||
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
|
||||
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
|
||||
const index_t start = ck_tile::max(origin_start, split_start);
|
||||
const index_t end = ck_tile::min(origin_end, split_end);
|
||||
const bool is_first_intersecting_split =
|
||||
(split_start <= origin_start && split_end >= origin_start);
|
||||
const bool sink_in_range = (sink_seq_end <= start);
|
||||
|
||||
const index_t sink_offset =
|
||||
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
|
||||
return ck_tile::make_tuple(sink_offset, start, end);
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
@@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
return i_x >= x_total;
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
@@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
@@ -620,6 +751,7 @@ static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Ma
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t sink_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
@@ -637,7 +769,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t x = 1 + right_size + x_tmp;
|
||||
index_t y = 1 + left_size + y_tmp;
|
||||
|
||||
return ck_tile::make_tuple(y, x, y_total, x_total);
|
||||
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t sink_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, sink_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
@@ -649,7 +795,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||
left_size, right_size, 0, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -162,6 +162,17 @@ struct StandardAttention
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool UseExp2 = false>
|
||||
@@ -224,6 +235,17 @@ struct LogitsSoftCap
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr uint32_t CUSTOM_MASK = 1U;
|
||||
@@ -297,6 +319,17 @@ struct ComposedAttention
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -200,7 +200,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -356,6 +356,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -418,6 +419,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -497,6 +499,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -557,6 +560,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -1008,6 +1012,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -58,6 +58,7 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -155,7 +156,7 @@ struct FmhaFwdKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -335,6 +336,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -393,6 +395,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -481,6 +484,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -529,6 +533,7 @@ struct FmhaFwdKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
@@ -580,6 +585,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -628,6 +634,7 @@ struct FmhaFwdKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
@@ -673,6 +680,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -732,6 +740,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -817,6 +826,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -861,6 +871,7 @@ struct FmhaFwdKernel
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
@@ -908,6 +919,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -952,6 +964,7 @@ struct FmhaFwdKernel
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
@@ -1443,6 +1456,7 @@ struct FmhaFwdKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
@@ -2182,6 +2196,7 @@ struct FmhaFwdKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
@@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type);
|
||||
}
|
||||
|
||||
@@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
{
|
||||
@@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
{
|
||||
@@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel
|
||||
batch_stride_v,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q);
|
||||
}
|
||||
@@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
|
||||
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
@@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
|
||||
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel
|
||||
struct MaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
@@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
@@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
// k_dram_block_window
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
@@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
// v_dram_window
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
// prefetch K tile
|
||||
@@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
auto physical_next_block_id_k =
|
||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
@@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
@@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto row, auto col) {
|
||||
return mask.IsOutOfSinkBound(row, col);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
||||
physical_next_block_id_v =
|
||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
else
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t logical_num_total_loop =
|
||||
@@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// store Q into LDS
|
||||
@@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
// load the second tile of the first iteration
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
@@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
@@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
@@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
{
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
|
||||
|
||||
k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
@@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
@@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
else
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
@@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
|
||||
auto physical_next_block_id_k =
|
||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
@@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
@@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
@@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
||||
physical_next_block_id_v =
|
||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
@@ -114,6 +115,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
@@ -167,6 +169,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
|
||||
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
// extract tile size attributes to remove dependency on traits
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
@@ -233,10 +234,26 @@ struct BlockFmhaPipelineQRKSVS
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// K and dropout use the same address in LDS, finish loading from k_lds_window by
|
||||
// gemm_0 to reuse LDS.
|
||||
block_sync_lds();
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
|
||||
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(!kHasSink)
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
|
||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
@@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
}
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
|
||||
@@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
clear_tile(l);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
@@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
@@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(!kHasSink)
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
|
||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
@@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
// move K tile windows
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
}
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
|
||||
@@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasUnevenSplits = true;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
|
||||
@@ -20,8 +20,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -36,6 +37,7 @@ struct TileFmhaTraits
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
@@ -65,8 +67,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kIsPagedKV_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaFwdPagedKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -81,6 +84,7 @@ struct TileFmhaFwdPagedKVTraits
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
@@ -95,7 +99,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
bool kMergeNumHeadGroupsSeqLenQ_ = false,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaFwdSplitKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -112,6 +117,7 @@ struct TileFmhaFwdSplitKVTraits
|
||||
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
Reference in New Issue
Block a user