mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 23:38:11 +00:00
Make sure we always start reading complete tile
This commit is contained in:
@@ -218,7 +218,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
// make sure the first tile is completely located in page-block
|
||||
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
@@ -242,8 +254,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
auto [i_block0, k_dram_block_window] =
|
||||
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
|
||||
auto [i_block0, k_dram_block_window] = k_tile_navigator.make_tile_window(
|
||||
k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
@@ -257,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window(
|
||||
v_dram_block_window_tmp,
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -384,17 +396,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
/// TODO: only check in last iteration without increasing code size
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
const auto k_origin = k_tile_navigator.to_global_window_origin(
|
||||
i_block0, k_dram_block_window.get_window_origin());
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
|
||||
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
|
||||
auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return seqlen_k_end_ <= col;
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user