Make sure we always start reading complete tile

This commit is contained in:
PoYen, Chen
2024-08-06 03:13:57 +00:00
parent 4fed268723
commit f9e2bafd10
2 changed files with 37 additions and 20 deletions

View File

@@ -218,7 +218,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// make sure the first tile is completely located in page-block
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
@@ -242,8 +254,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
auto [i_block0, k_dram_block_window] =
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
auto [i_block0, k_dram_block_window] = k_tile_navigator.make_tile_window(
k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
@@ -257,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window(
v_dram_block_window_tmp,
{0, seqlen_k_start}, // TODO: hdim split?
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -384,17 +396,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in last iteration without increasing code size
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_tile_navigator.to_global_window_origin(
i_block0, k_dram_block_window.get_window_origin());
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col;
if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
});
}