diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md
index b9f9dafe89..3f5ff84f69 100644
--- a/example/ck_tile/18_hstu_attention/README.md
+++ b/example/ck_tile/18_hstu_attention/README.md
@@ -44,7 +44,7 @@
         .insert("nhead", "4", "number of heads")
         .insert("hdim_qk", "64", "headdim size of Q/K")
         .insert("hdim_v", "64", "headdim size of V/O")
-        .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
+        .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor; the actually allocated seqlen will also include each batch's target length and the context_len")
         .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
         .insert("causal", "1", "enable causal mask or not")
         .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
diff --git a/example/ck_tile/18_hstu_attention/bench_jagged_causal.sh b/example/ck_tile/18_hstu_attention/bench_jagged_causal.sh
index fe195234f3..73c2c894c9 100644
--- a/example/ck_tile/18_hstu_attention/bench_jagged_causal.sh
+++ b/example/ck_tile/18_hstu_attention/bench_jagged_causal.sh
@@ -10,17 +10,16 @@
 num_batch=32
 num_head=4
 target=20

-add_target()
+add_comma()
 {
     x=$*
     y=""

     for len in $x; do
-        new_len=$(($len + $target));
         if test -z $y; then
-            y="$new_len"
+            y="$len"
         else
-            y="$y,$new_len"
+            y="$y,$len"
         fi;
     done
@@ -33,11 +32,11 @@ sl4096="1497 2516 3179 2891 190 3572 640 3025 464 1824 712 1519 2727 2621 1135 7
 sl8192="4571 3202 270 1540 8169 3365 6055 7181 2942 4213 2717 3593 7748 4646 5502 4489 6525 2481 7397 2983 5667 1003 7926 3659 6129 6647 3758 6244 4175 2327 849 5261"
 sl16384="6956 7177 338 13755 10382 13392 10150 15592 15929 5256 6825 3804 5197 13415 14099 12418 13772 13659 5998 3715 9862 9183 11826 12964 6041 6712 12846 475 4672 7690 12280 10175"

-s_sl1024=`add_target $sl1024`
-s_sl2048=`add_target $sl2048`
-s_sl4096=`add_target $sl4096`
-s_sl8192=`add_target $sl8192`
-s_sl16384=`add_target $sl16384`
+s_sl1024=`add_comma $sl1024`
+s_sl2048=`add_comma $sl2048`
+s_sl4096=`add_comma $sl4096`
+s_sl8192=`add_comma $sl8192`
+s_sl16384=`add_comma $sl16384`

 set -x

diff --git a/example/ck_tile/18_hstu_attention/bench_jagged_causal_local.sh b/example/ck_tile/18_hstu_attention/bench_jagged_causal_local.sh
index aacd8137bc..cfee8af331 100644
--- a/example/ck_tile/18_hstu_attention/bench_jagged_causal_local.sh
+++ b/example/ck_tile/18_hstu_attention/bench_jagged_causal_local.sh
@@ -11,17 +11,16 @@
 num_head=4
 window_size=5
 target=20

-add_target()
+add_comma()
 {
     x=$*
     y=""

     for len in $x; do
-        new_len=$(($len + $target));
         if test -z $y; then
-            y="$new_len"
+            y="$len"
         else
-            y="$y,$new_len"
+            y="$y,$len"
         fi;
     done
@@ -35,12 +34,12 @@ sl8192="4571 3202 270 1540 8169 3365 6055 7181 2942 4213 2717 3593 7748 4646 550
 sl16384="6956 7177 338 13755 10382 13392 10150 15592 15929 5256 6825 3804 5197 13415 14099 12418 13772 13659 5998 3715 9862 9183 11826 12964 6041 6712 12846 475 4672 7690 12280 10175"
 sl32768="28810 1574 80 24581 32298 19576 8028 25764 16544 14321 22771 7622 21090 27370 15921 5841 5458 23228 23619 17897 11996 31636 23183 20444 26332 7742 3418 9181 4750 18744 5201 2019"

-s_sl1024=`add_target $sl1024`
-s_sl2048=`add_target $sl2048`
-s_sl4096=`add_target $sl4096`
-s_sl8192=`add_target $sl8192`
-s_sl16384=`add_target $sl16384`
-s_sl32768=`add_target $sl32768`
+s_sl1024=`add_comma $sl1024`
+s_sl2048=`add_comma $sl2048`
+s_sl4096=`add_comma $sl4096`
+s_sl8192=`add_comma $sl8192`
+s_sl16384=`add_comma $sl16384`
+s_sl32768=`add_comma $sl32768`

 set -x

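Review note: `add_target` used to fold `$target` into every sequence length before comma-joining; after this change the renamed `add_comma` helper only joins the raw uih lengths, and the target/contextual accounting moves into the C++ example below. The following is a minimal standalone sketch of that accounting, not the example's actual code; the variable names mirror the example's and the input values are purely illustrative:

```cpp
#include <cstdio>
#include <vector>

// Sketch of the new length accounting: the uih lengths passed by the script
// stay untouched, and the example itself adds each batch's target length and
// the contextual length when computing the physically allocated sequence.
int main()
{
    std::vector<int> seq_lengths = {1497, 2516, 3179}; // raw uih lengths (illustrative)
    std::vector<int> num_targets = {20, 20, 20};       // per-batch target lengths (illustrative)
    int contextual_seqlen        = 0;                  // assumed 0, i.e. contexts disabled

    std::vector<int> seq_offsets = {0};
    int phy_seqlen               = 0;

    for(size_t i = 0; i < seq_lengths.size(); i++)
    {
        phy_seqlen += seq_lengths[i] + num_targets[i] + contextual_seqlen;
        seq_offsets.push_back(phy_seqlen);
    }

    std::printf("phy_seqlen = %d\n", phy_seqlen); // 1517 + 2536 + 3199 = 7252
    return 0;
}
```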
diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
index 6dc5bb8443..ea95f8c1a3 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
@@ -77,7 +77,7 @@ auto create_args(int argc, char* argv[])
         .insert("nhead", "4", "number of heads")
         .insert("hdim_qk", "64", "headdim size of Q/K")
         .insert("hdim_v", "64", "headdim size of V/O")
-        .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
+        .insert("seqlen", "400", "uih seqlen of single or all batches for query and key/value tensor; the actually allocated seqlen will also include each batch's target length and the context_len")
         .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
         .insert("causal", "1", "enable causal mask or not")
         .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
@@ -206,25 +206,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::string str_of_lengths   = arg_parser.get_str("seqlen");
     std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);

-    std::vector<int> seq_offsets;
+    int uih_seqlen     = 0; // uih (user interaction history) seqlen, used for the non-jagged case
+    int max_uih_seqlen = 0;
+    int max_target     = 0;

-    int seqlen     = 0; // means total seq lengths for jagged
-    int max_seqlen = 0;
-
-    // supplement the sequence of lengths to have num_batch values
-    if(!num_targets.empty() && static_cast<int>(num_targets.size()) < num_batch)
+    if(!num_targets.empty())
     {
-        auto last_val = num_targets.back();
+        // supplement num_targets with its last value if the user provides fewer than num_batch entries
+        if(static_cast<int>(num_targets.size()) < num_batch)
+        {
+            auto last_val = num_targets.back();

-        for(int i = num_targets.size(); i < num_batch; i++)
-            num_targets.push_back(last_val);
+            for(int i = num_targets.size(); i < num_batch; i++)
+                num_targets.push_back(last_val);
+        };
+
+        // only consider the first num_batch values even if the user provides more
+        for(int i = 0; i < num_batch; i++)
+            max_target = max(max_target, num_targets[i]);
     };

     if(is_jagged)
     {
-        assert(num_batch >= seq_lengths.size());
-
-        // supplement the sequence of lengths to have num_batch values
+        // supplement seq_lengths with its last value if the user provides fewer than num_batch entries
         if(static_cast<int>(seq_lengths.size()) < num_batch)
         {
             auto last_len = seq_lengths.back();
@@ -233,54 +237,49 @@
             seq_lengths.push_back(last_len);
         };

-        seq_offsets.push_back(0);
-        for(size_t i = 0; i < seq_lengths.size(); i++)
+        // only consider the first num_batch values even if the user provides more
+        for(int i = 0; i < num_batch; i++)
         {
-            max_seqlen = max(max_seqlen, seq_lengths[i]);
-            seqlen += seq_lengths[i];
-            seq_offsets.push_back(seqlen);
+            max_uih_seqlen = max(max_uih_seqlen, seq_lengths[i]);
         };

-        if(!num_targets.empty())
+        // only consider the first num_batch values even if the user provides more
+        for(int i = 0; i < num_batch; i++)
         {
-            assert(num_batch == num_targets.size());
-
-            for(size_t i = 0; i < seq_lengths.size(); i++)
-            {
-                assert(seq_lengths[i] - num_targets[i] >= min_full_attn_seqlen);
-                assert(seq_lengths[i] - num_targets[i] >= contextual_seqlen);
-            };
-        }
-        else
-        {
-            for(size_t i = 0; i < seq_lengths.size(); i++)
-            {
-                assert(seq_lengths[i] >= min_full_attn_seqlen);
-                assert(seq_lengths[i] >= contextual_seqlen);
-            };
+            assert(seq_lengths[i] >= min_full_attn_seqlen);
         };
     }
     else
     {
         assert(1 == seq_lengths.size());

-        seqlen     = seq_lengths[0];
-        max_seqlen = seqlen;
+        uih_seqlen     = seq_lengths[0];
+        max_uih_seqlen = uih_seqlen;

-        if(!num_targets.empty())
-        {
-            assert(num_batch == num_targets.size());
+        assert(uih_seqlen >= min_full_attn_seqlen);
+    };

-            for(size_t i = 0; i < seq_lengths.size(); i++)
-            {
-                assert(seqlen - num_targets[i] >= min_full_attn_seqlen);
-                assert(seqlen - num_targets[i] >= contextual_seqlen);
-            };
-        }
-        else
+    int phy_seqlen = 0; // physically allocated seqlen over all batches
+    int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
+
+    std::vector<int> seq_offsets;
+
+    if(is_jagged)
+    {
+        seq_offsets.push_back(0);
+
+        for(int i = 0; i < num_batch; i++)
         {
-            assert(seqlen >= min_full_attn_seqlen);
-            assert(seqlen >= contextual_seqlen);
+            int batch_seqlen = num_targets.empty()
+                                   ? seq_lengths[i] + contextual_seqlen
+                                   : seq_lengths[i] + num_targets[i] + contextual_seqlen;
+
+            phy_seqlen += batch_seqlen;
+            seq_offsets.push_back(phy_seqlen);
         };
+    }
+    else
+    {
+        phy_seqlen = max_seqlen;
     };
     long total_flops = 0;
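Review note: once targets and contextual tokens are folded into `seq_offsets`, the per-batch allocated length can be recovered as an offset difference, which is exactly what the next hunk does for the FLOP estimate. A small sketch of that estimate under the same assumptions (two GEMMs per head, `2 * len * len * hdim` each, scaling and SiLU ignored); the `estimate_flops` helper name and the sample numbers are mine, not the example's:

```cpp
#include <cstdio>
#include <vector>

// FLOP estimate for jagged batches: each batch contributes two GEMMs
// (S = Q*K^T and O = P*V), i.e. 2 * len * len * hdim multiply-adds each.
long estimate_flops(const std::vector<int>& seq_offsets, int num_head, int hdim_qk, int hdim_v)
{
    long total_flops = 0;

    for(size_t i = 0; i + 1 < seq_offsets.size(); i++)
    {
        long len = seq_offsets[i + 1] - seq_offsets[i]; // allocated length of batch i
        total_flops += (len * len * hdim_qk + len * hdim_v * len) * 2;
    }

    return total_flops * num_head;
}

int main()
{
    std::vector<int> seq_offsets = {0, 1517, 4053, 7252}; // offsets from the sketch above
    // 4 heads, hdim_qk = hdim_v = 64 -> 19,421,374,464 flops for these offsets
    std::printf("total flops = %ld\n", estimate_flops(seq_offsets, 4, 64, 64));
    return 0;
}
```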
@@ -288,31 +287,34 @@
     // estimate the total flops occurred, ignoring the scaling and SILu
     if(is_jagged)
     {
-        for(auto len : seq_lengths)
+        for(int i = 0; i < num_batch; i++)
+        {
+            int len = seq_offsets[i + 1] - seq_offsets[i];
             total_flops += (static_cast<long>(len) * len * hdim_qk +
                             static_cast<long>(len) * hdim_v * len) *
                            2;
+        };

         total_flops *= num_head;
     }
     else
     {
         total_flops = static_cast<long>(num_batch) * num_head *
-                      (static_cast<long>(seqlen) * seqlen * hdim_qk +
-                       static_cast<long>(seqlen) * hdim_v * seqlen) *
+                      (static_cast<long>(phy_seqlen) * phy_seqlen * hdim_qk +
+                       static_cast<long>(phy_seqlen) * hdim_v * phy_seqlen) *
                       2;
     };

     int batches_for_alloc = is_jagged ? 1 : num_batch;

     ck_tile::HostTensor q_host(
-        std::array{batches_for_alloc, seqlen, num_head, hdim_qk});
+        std::array{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
     ck_tile::HostTensor k_host(
-        std::array{batches_for_alloc, seqlen, num_head, hdim_qk});
+        std::array{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
     ck_tile::HostTensor v_host(
-        std::array{batches_for_alloc, seqlen, num_head, hdim_v});
+        std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v});
     ck_tile::HostTensor o_host_ref(
-        std::array{batches_for_alloc, seqlen, num_head, hdim_v});
+        std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v});

     ck_tile::HostTensor mask_host(
         save_mask ? std::array{num_batch, num_head, max_seqlen, max_seqlen}
@@ -379,7 +381,7 @@
     {
         params.is_jagged = false;
         params.num_batch = num_batch;
-        params.seqlen    = seqlen;
+        params.seqlen    = max_seqlen;
         params.q_ptr     = q_dev.GetDeviceBuffer();
         params.k_ptr     = k_dev.GetDeviceBuffer();
         params.v_ptr     = v_dev.GetDeviceBuffer();
@@ -458,7 +460,7 @@
                       mask_host,
                       num_batch,
                       1.0f / std::sqrt(params.hdim_qk),
-                      is_jagged ? max_seqlen : seqlen,
+                      max_seqlen,
                       seq_offsets,
                       num_targets,
                       window_size,
@@ -467,7 +469,7 @@
         });

     ck_tile::HostTensor o_host(
-        std::array{batches_for_alloc, seqlen, num_head, hdim_v});
+        std::array{batches_for_alloc, phy_seqlen, num_head, hdim_v});

     o_dev.FromDevice(o_host.data());

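Review note: after this change, `params.seqlen` in the non-jagged path is the fully padded `max_seqlen` rather than the raw uih length. A worked check with the README defaults (`seqlen=400`, `targets=16`, contextual tokens assumed disabled): every batch is allocated 400 + 16 + 0 = 416 tokens, so Q/K/V/O are shaped `{num_batch, 416, nhead, hdim}`. A minimal sketch of that invariant, with values taken from the README defaults:

```cpp
#include <cassert>

// Sanity check of the non-jagged sizing after the refactor: the physically
// allocated length equals uih + target + contextual tokens for every batch.
int main()
{
    int uih_seqlen        = 400; // README default for "seqlen"
    int max_target        = 16;  // README default for "targets"
    int contextual_seqlen = 0;   // assumed disabled

    int max_seqlen = uih_seqlen + max_target + contextual_seqlen;
    int phy_seqlen = max_seqlen; // non-jagged: every batch gets the padded size

    assert(phy_seqlen == 416);
    return 0;
}
```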