Add max_seqlen_k (max KV sequence length) to the unified attention (UA) argument structure

This commit is contained in:
Damien Lejeune
2026-04-23 15:22:09 +00:00
parent ce751cf74d
commit 977af0e511
3 changed files with 5 additions and 2 deletions

View File

@@ -420,6 +420,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
};
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
args.max_seqlen_k = max_kv_len;
ck_tile::index_t max_num_blocks_per_seq =
(max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;

View File

@@ -80,7 +80,8 @@ std::ostream& operator<<(std::ostream& stream, const unified_attention_args& arg
// stream << ", query_start_len_ptr=";
// write_ptr(stream, static_cast<const void*>(args.query_start_len_ptr));
return stream << ", num_seqs=" << args.num_seqs
<< ", max_seqlen_q=" << args.max_seqlen_q << " }";
<< ", max_seqlen_q=" << args.max_seqlen_q
<< ", max_seqlen_k=" << args.max_seqlen_k << " }";
}
// Helper macro to reduce dispatch boilerplate.

View File

@@ -66,7 +66,8 @@ struct unified_attention_args
const int32_t* query_start_len_ptr; // [num_seqs+1]
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
index_t max_seqlen_k = 0; // max KV length across seqs in seq_lens (0 = unknown / not set)
};
std::ostream& operator<<(std::ostream& stream,