mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Fix incorrect computation of group mode PagedAttention (#1688)
* Allow getting batch size from splitkv tile partitioner * Fix wrong paged-kvcache impl for group mode * Fix wrong example code for page-kvcache * Undo changes in fmha_fwd.cpp * Always use 2D block table * Add is_gappy kernel argument for paged-kvcache The is_gappy argument is used for differentiating seqstart_k_ptr usage in flash-attention & xformers * Remove out-of-date comments * Remove no-longer used method * Fix wrong # page-block calculation * Fix wrong comment --------- Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
@@ -143,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -211,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t original_num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
if(original_num_total_loop <= 0)
|
||||
const index_t logical_num_total_loop =
|
||||
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
|
||||
if(logical_num_total_loop <= 0)
|
||||
{
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -239,33 +240,41 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
// make sure the first tile is completely located in page-block
|
||||
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
// divisible by kN0)
|
||||
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start =
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -379,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
// position_encoding accept only logical coordinates, do conversion here
|
||||
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -397,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
|
||||
auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
set_tile_if(
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
@@ -428,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -659,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
@@ -681,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user