From 80e08b6efea852685333b7965760702dedb18544 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 1 Nov 2025 16:20:38 +0000 Subject: [PATCH] Use supplement_array_by_last_element() in example to simplify the codes --- .../example_hstu_attention.cpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index ea4c6b1165..1354ebd50a 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -155,6 +155,18 @@ static std::vector get_integers_from_string(std::string lengthsStr) return (lengths); }; +template +void supplement_array_by_last_element(std::vector& arr, int target_num_elements) +{ + if(static_cast(arr.size()) < target_num_elements) + { + T last_val = arr.back(); + + for(int i = arr.size(); i < target_num_elements; i++) + arr.push_back(last_val); + }; +}; + static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdParams& param) { if(param.is_jagged) @@ -275,23 +287,11 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_jagged) { - // supplement seq_lengths using the last input value if user-provided lengths not enough - if(static_cast(seq_lengths_q.size()) < num_batch) - { - auto last_len = seq_lengths_q.back(); - - for(int i = seq_lengths_q.size(); i < num_batch; i++) - seq_lengths_q.push_back(last_len); - }; + // supplement seq_lengths_q using the last input value if user-provided lengths not enough + supplement_array_by_last_element(seq_lengths_q, num_batch); // supplement seq_lengths_kv using the last input value if user-provided lengths not enough - if(static_cast(seq_lengths_kv.size()) < num_batch) - { - auto last_len = seq_lengths_kv.back(); - - for(int i = seq_lengths_kv.size(); i < num_batch; i++) - seq_lengths_kv.push_back(last_len); - }; + supplement_array_by_last_element(seq_lengths_kv, num_batch); // only consider num_batch values even if more values are provided by the user for(int i = 0; i < num_batch; i++)