mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Fix the calculation of total_flops and update the benchmark scripts
This commit is contained in:
@@ -4,13 +4,15 @@ BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
for seqlen in 512 1024 3072; do
|
||||
set -x
|
||||
|
||||
## jagged is true
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
## jagged is true
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
## jagged is false
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
## jagged is false
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
set +x
|
||||
set +x
|
||||
done
|
||||
done
|
||||
|
||||
@@ -285,13 +285,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_jagged)
|
||||
{
|
||||
for(auto len : seq_lengths)
|
||||
total_flops += 2 * (len * len * hdim_qk + len * hdim_v * len);
|
||||
total_flops += (static_cast<long>(len) * len * hdim_qk + len * hdim_v * len) * 2;
|
||||
|
||||
total_flops *= num_head;
|
||||
}
|
||||
else
|
||||
{
|
||||
total_flops = num_batch * num_head * (seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen);
|
||||
total_flops = static_cast<long>(num_batch) * num_head *
|
||||
(seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen) * 2;
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
Reference in New Issue
Block a user