Use supplement_array_by_last_element() in example to simplify the codes

This commit is contained in:
Qianfeng Zhang
2025-11-01 16:20:38 +00:00
parent 10133e5d51
commit 80e08b6efe

View File

@@ -155,6 +155,18 @@ static std::vector<int> get_integers_from_string(std::string lengthsStr)
return (lengths);
};
template <typename T>
void supplement_array_by_last_element(std::vector<T>& arr, int target_num_elements)
{
if(static_cast<int>(arr.size()) < target_num_elements)
{
T last_val = arr.back();
for(int i = arr.size(); i < target_num_elements; i++)
arr.push_back(last_val);
};
};
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
{
if(param.is_jagged)
@@ -275,23 +287,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
// supplement seq_lengths using the last input value if user-provided lengths not enough
if(static_cast<int>(seq_lengths_q.size()) < num_batch)
{
auto last_len = seq_lengths_q.back();
for(int i = seq_lengths_q.size(); i < num_batch; i++)
seq_lengths_q.push_back(last_len);
};
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
supplement_array_by_last_element(seq_lengths_q, num_batch);
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
if(static_cast<int>(seq_lengths_kv.size()) < num_batch)
{
auto last_len = seq_lengths_kv.back();
for(int i = seq_lengths_kv.size(); i < num_batch; i++)
seq_lengths_kv.push_back(last_len);
};
supplement_array_by_last_element(seq_lengths_kv, num_batch);
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)