Make sure we always start reading complete tile

This commit is contained in:
PoYen, Chen
2024-08-06 03:13:57 +00:00
parent 4fed268723
commit f9e2bafd10
2 changed files with 37 additions and 20 deletions

View File

@@ -560,7 +560,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<VDataType> v_host(
USE_PAGED_VCACHE && 0 < page_block_size
0 < page_block_size
? (is_v_rowmajor
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v)
: get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size))
@@ -884,9 +884,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return USE_PAGED_VCACHE && 0 < page_block_size
? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
@@ -899,13 +898,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return USE_PAGED_VCACHE && 0 < page_block_size
? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
else
return USE_PAGED_VCACHE && 0 < page_block_size
? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
@@ -920,8 +917,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
: (nhead_k * shape_seqlen_k * hdim_q));
const ck_tile::index_t batch_stride_v =
(USE_PAGED_VCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
@@ -1128,7 +1125,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
});
}
#endif
if (USE_PAGED_VCACHE && 0 < page_block_size) {
if (0 < page_block_size) {
if (is_v_rowmajor) {
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {

View File

@@ -218,7 +218,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// make sure the first tile is completely located in page-block
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
@@ -242,8 +254,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
auto [i_block0, k_dram_block_window] =
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
auto [i_block0, k_dram_block_window] = k_tile_navigator.make_tile_window(
k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
@@ -257,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window(
v_dram_block_window_tmp,
{0, seqlen_k_start}, // TODO: hdim split?
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -384,17 +396,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in last iteration without increasing code size
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_tile_navigator.to_global_window_origin(
i_block0, k_dram_block_window.get_window_origin());
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col;
if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
});
}