mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Update to the scripts and error thresholds
This commit is contained in:
@@ -4,13 +4,13 @@ BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## jagged is true
|
||||
cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1"
|
||||
echo $cmd
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
## jagged is false
|
||||
cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1"
|
||||
echo $cmd
|
||||
$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
@@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 2e-3;
|
||||
double atol = 2e-3;
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
@@ -171,8 +171,8 @@ auto get_elimit()
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
double rtol = 2e-2;
|
||||
double atol = 2e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -292,7 +292,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<InOutDataType>{0.f, 1.f, seed}(v_host);
|
||||
|
||||
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -3,21 +3,26 @@
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user