mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Add benchmark_hstu_attention.sh
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
## jagged is true
|
||||
cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1"
|
||||
echo $cmd
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
## jagged is false
|
||||
cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1"
|
||||
echo $cmd
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
done
|
||||
@@ -440,8 +440,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
// dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
// dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
@@ -454,7 +454,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::gpu_timer timer{};
|
||||
|
||||
timer.start(stream);
|
||||
for(int i = 0; i < 20; i++)
|
||||
for(int i = 0; i < 10; i++)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
@@ -473,9 +473,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
timer.stop(stream);
|
||||
|
||||
auto ms = timer.duration() / 20.f;
|
||||
auto ms = timer.duration() / 10.f;
|
||||
|
||||
std::cout << "Average execution time of the hstu_attention operator is " << ms
|
||||
std::cout << "Average execution time of the hstu_attention operation is " << ms
|
||||
<< " milli-seconds" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
## no masking batched
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## no masking jagged
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## jagged causal
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal+local
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
## jagged causal+local
|
||||
bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
|
||||
Reference in New Issue
Block a user