Fix in calculation of total_flops and update benchmark scripts

This commit is contained in:
Qianfeng Zhang
2025-04-13 08:50:00 +00:00
parent 71697d9cb9
commit 53e567977e
2 changed files with 11 additions and 8 deletions

View File

@@ -4,13 +4,15 @@ BUILD=build
EXE=$BUILD/bin/tile_example_hstu_attention
for dtype in "fp16" "bf16"; do
set -x
for seqlen in 512 1024 3072; do
set -x
## jagged is true
$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
## jagged is true
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
## jagged is false
$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
## jagged is false
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
set +x
set +x
done
done

View File

@@ -285,13 +285,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
for(auto len : seq_lengths)
total_flops += 2 * (len * len * hdim_qk + len * hdim_v * len);
total_flops += (static_cast<long>(len) * len * hdim_qk + len * hdim_v * len) * 2;
total_flops *= num_head;
}
else
{
total_flops = num_batch * num_head * (seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen);
total_flops = static_cast<long>(num_batch) * num_head *
(seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen) * 2;
};
int batches_for_alloc = is_jagged ? 1 : num_batch;