diff --git a/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh b/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh index 0c08fdf52a..c02b13d4f4 100644 --- a/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh @@ -4,13 +4,15 @@ BUILD=build EXE=$BUILD/bin/tile_example_hstu_attention for dtype in "fp16" "bf16"; do - set -x + for seqlen in 512 1024 3072; do + set -x - ## jagged is true - $EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 + ## jagged is true + $EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 - ## jagged is false - $EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 + ## jagged is false + $EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 - set +x + set +x + done done diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index bafeef361e..b33a7f8dee 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -285,13 +285,14 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_jagged) { for(auto len : seq_lengths) - total_flops += 2 * (len * len * hdim_qk + len * hdim_v * len); + total_flops += (static_cast(len) * len * hdim_qk + len * hdim_v * len) * 2; total_flops *= num_head; } else { - total_flops = num_batch * num_head * (seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen); + total_flops = static_cast(num_batch) * num_head * + (seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen) * 2; }; int batches_for_alloc = is_jagged ? 1 : num_batch;