From 1766e6d3be592d2068f55bbedd556ea342def3a0 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 9 Apr 2025 10:34:37 +0000 Subject: [PATCH] Update to the scripts and error thresholds --- .../benchmark_hstu_attention.sh | 8 ++--- .../example_hstu_attention.cpp | 10 +++---- .../18_hstu_attention/test_hstu_attention.sh | 29 +++++++++++-------- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh b/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh index 0c2c10af50..0c08fdf52a 100644 --- a/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/benchmark_hstu_attention.sh @@ -4,13 +4,13 @@ BUILD=build EXE=$BUILD/bin/tile_example_hstu_attention for dtype in "fp16" "bf16"; do + set -x + ## jagged is true - cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1" - echo $cmd $EXE -v=0 -prec=$dtype -b=80 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 ## jagged is false - cmd="$EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1" - echo $cmd $EXE -v=0 -prec=$dtype -b=80 -jagged=0 -nhead=8 -hdim_qk=128 -hdim_v=128 -seqlen=1000 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1 + + set +x done diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index afe3951ca4..bb6bb9050f 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara template auto get_elimit() { - double rtol = 2e-3; - double atol = 2e-3; + double rtol = 1e-2; + double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } @@ -171,8 +171,8 @@ auto get_elimit() template <> auto get_elimit() { - double rtol = 1e-2; - double atol = 1e-2; + double rtol = 2e-2; + double atol = 2e-2; return ck_tile::make_tuple(rtol, atol); } @@ -292,7 +292,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillNormalDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(k_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(v_host); ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/18_hstu_attention/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/test_hstu_attention.sh index 1b42f60efc..88b8f812a3 100644 --- a/example/ck_tile/18_hstu_attention/test_hstu_attention.sh +++ b/example/ck_tile/18_hstu_attention/test_hstu_attention.sh @@ -3,21 +3,26 @@ BUILD=build EXE=$BUILD/bin/tile_example_hstu_attention -## no masking batched -$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 +for dtype in "fp16" "bf16"; do + set -x -## no masking jagged -$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + ## no masking batched + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -## batched causal -$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + ## no masking jagged + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -## jagged causal -$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 + ## batched causal + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -## batched causal+local -$EXE -v=1 -prec=fp16 -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + ## jagged causal + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -## jagged causal+local -$EXE -v=1 -prec=fp16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + ## batched causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + ## jagged causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 + + set +x +done