mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Add max len k to UA argument structure
This commit is contained in:
@@ -420,6 +420,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
};
|
||||
|
||||
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
|
||||
args.max_seqlen_k = max_kv_len;
|
||||
|
||||
ck_tile::index_t max_num_blocks_per_seq =
|
||||
(max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
|
||||
|
||||
@@ -80,7 +80,8 @@ std::ostream& operator<<(std::ostream& stream, const unified_attention_args& arg
|
||||
// stream << ", query_start_len_ptr=";
|
||||
// write_ptr(stream, static_cast<const void*>(args.query_start_len_ptr));
|
||||
return stream << ", num_seqs=" << args.num_seqs
|
||||
<< ", max_seqlen_q=" << args.max_seqlen_q << " }";
|
||||
<< ", max_seqlen_q=" << args.max_seqlen_q
|
||||
<< ", max_seqlen_k=" << args.max_seqlen_k << " }";
|
||||
}
|
||||
|
||||
// Helper macro to reduce dispatch boilerplate.
|
||||
|
||||
@@ -66,7 +66,8 @@ struct unified_attention_args
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
index_t max_seqlen_k = 0; // max KV length across seqs in seq_lens (0 = unknown / not set)
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
|
||||
Reference in New Issue
Block a user