mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Use supplement_array_by_last_element() in example to simplify the codes
This commit is contained in:
@@ -155,6 +155,18 @@ static std::vector<int> get_integers_from_string(std::string lengthsStr)
|
||||
return (lengths);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void supplement_array_by_last_element(std::vector<T>& arr, int target_num_elements)
|
||||
{
|
||||
if(static_cast<int>(arr.size()) < target_num_elements)
|
||||
{
|
||||
T last_val = arr.back();
|
||||
|
||||
for(int i = arr.size(); i < target_num_elements; i++)
|
||||
arr.push_back(last_val);
|
||||
};
|
||||
};
|
||||
|
||||
static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param)
|
||||
{
|
||||
if(param.is_jagged)
|
||||
@@ -275,23 +287,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// supplement seq_lengths using the last input value if user-provided lengths not enough
|
||||
if(static_cast<int>(seq_lengths_q.size()) < num_batch)
|
||||
{
|
||||
auto last_len = seq_lengths_q.back();
|
||||
|
||||
for(int i = seq_lengths_q.size(); i < num_batch; i++)
|
||||
seq_lengths_q.push_back(last_len);
|
||||
};
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
supplement_array_by_last_element(seq_lengths_q, num_batch);
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
|
||||
if(static_cast<int>(seq_lengths_kv.size()) < num_batch)
|
||||
{
|
||||
auto last_len = seq_lengths_kv.back();
|
||||
|
||||
for(int i = seq_lengths_kv.size(); i < num_batch; i++)
|
||||
seq_lengths_kv.push_back(last_len);
|
||||
};
|
||||
supplement_array_by_last_element(seq_lengths_kv, num_batch);
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
|
||||
Reference in New Issue
Block a user