mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Make sure we always start reading complete tile
This commit is contained in:
@@ -560,7 +560,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
USE_PAGED_VCACHE && 0 < page_block_size
|
||||
0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v)
|
||||
: get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size))
|
||||
@@ -884,9 +884,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
@@ -899,13 +898,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
else
|
||||
return USE_PAGED_VCACHE && 0 < page_block_size
|
||||
? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
@@ -920,8 +917,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(USE_PAGED_VCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
@@ -1128,7 +1125,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
}
|
||||
#endif
|
||||
if (USE_PAGED_VCACHE && 0 < page_block_size) {
|
||||
if (0 < page_block_size) {
|
||||
if (is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
|
||||
@@ -218,7 +218,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
// make sure the first tile is completely located in page-block
|
||||
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
@@ -242,8 +254,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
auto [i_block0, k_dram_block_window] =
|
||||
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
|
||||
auto [i_block0, k_dram_block_window] = k_tile_navigator.make_tile_window(
|
||||
k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
@@ -257,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
|
||||
auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window(
|
||||
v_dram_block_window_tmp,
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -384,17 +396,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
/// TODO: only check in last iteration without increasing code size
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
const auto k_origin = k_tile_navigator.to_global_window_origin(
|
||||
i_block0, k_dram_block_window.get_window_origin());
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
|
||||
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
|
||||
auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return seqlen_k_end_ <= col;
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user