Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-07 00:04:37 +00:00)
Rename 'max_num_blocks' to 'max_num_page_blocks'
This commit is contained in:
@@ -496,7 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
-const ck_tile::index_t max_num_blocks =
+const ck_tile::index_t max_num_page_blocks =
|
||||
(0 < page_block_size
|
||||
? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
|
||||
: 0);
|
||||
@@ -546,8 +546,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
-0 < page_block_size ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q)
-: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+0 < page_block_size
+? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
+: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
@@ -556,8 +557,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
-? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v)
-: get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size))
+? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
+: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
ck_tile::HostTensor<VDataType> vnew_host(
|
||||
@@ -601,7 +602,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
ck_tile::HostTensor<int32_t> block_table_host(
|
||||
-0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_blocks / batch}
+0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
@@ -821,7 +822,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
-const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch);
+const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
|
||||
return fmha_fwd_appendkv_args{
|
||||
q_buf.GetDeviceBuffer(),
|
||||
@@ -953,7 +954,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
-const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch);
+const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
Reference in New Issue
Block a user