diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 3cd9fab8f2..e32f6f33a6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -496,7 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - const ck_tile::index_t max_num_blocks = + const ck_tile::index_t max_num_page_blocks = (0 < page_block_size ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) : 0); @@ -546,8 +546,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor k_host( - 0 < page_block_size ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q) - : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + 0 < page_block_size + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) + : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( 0 < seqlen_knew @@ -556,8 +557,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor v_host( 0 < page_block_size ? (is_v_rowmajor - ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v) - : get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size)) + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); ck_tile::HostTensor vnew_host( @@ -601,7 +602,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1}); ck_tile::HostTensor block_table_host( - 0 < page_block_size ? std::array{batch, max_num_blocks / batch} + 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} : std::array{1, 1}); if(init_method == "ui" || init_method == "0") @@ -821,7 +822,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); return fmha_fwd_appendkv_args{ q_buf.GetDeviceBuffer(), @@ -953,7 +954,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);