From 52eff34d215c532bfc5817c96c90fda57bc473bf Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 27 Jun 2026 15:48:54 +0000 Subject: [PATCH] Add scripts for testing hstu attention fwd with drop-out --- ...tu_attention_hdim96_hdim64_with_dropout.sh | 58 ++++++++++ .../test_hstu_attention_with_dropout.sh | 28 ++--- ...jagged_causal_mattn0_full0_with_dropout.sh | 104 ++++++++++++++++++ ...ed_causal_mattn256_full256_with_dropout.sh | 104 ++++++++++++++++++ 4 files changed, 280 insertions(+), 14 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64_with_dropout.sh create mode 100644 example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0_with_dropout.sh create mode 100644 example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256_with_dropout.sh diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64_with_dropout.sh new file mode 100644 index 0000000000..a971f32413 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64_with_dropout.sh @@ -0,0 +1,58 @@ +#!/bin/bash +## This script can be used for verifying the use of WarpGemm 32x32x16, which is used by hdim64 + softmax + +BUILD=build +EXE="$BUILD/bin/tile_example_hstu_attention -softmax=0 -p_drop=0.2" + +attn_scale=1.0 +ndist=1 + +dtype="fp16" + +for hdim in 96 64; do + set -x + + ## no masking batched + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## no masking jagged + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal + $EXE
-v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim 
-seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + set +x +done diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh index eef7fe6b8f..f41834d526 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh @@ -18,45 +18,45 @@ for dtype in "fp16" "bf16"; do set -x ## no masking batched - $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## no masking jagged - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 
-context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## batched causal - $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## batched causal+local - $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal+local - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## 
batched causal+local+context - $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal+local+context - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist ## batched causal+local+context+target - $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal+local+context+target - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged no-causal+local+context+target - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale 
-norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal+local+target (minfull_len > max_uih_len) - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged causal+local+context+target (minfull_len > max_uih_len) - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist ## jagged no-causal+local+context+target (minfull_len > max_uih_len) - $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist set +x done diff --git a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0_with_dropout.sh new file mode 100644 index 0000000000..dbefca089e --- /dev/null +++ 
b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0_with_dropout.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +set +x + +ndist=0 +BUILD=build + +USE_SOFTMAX=0 +if [ $# -ge 1 ]; then + USE_SOFTMAX=$1 +fi + +Training=${TEST_HSTU_FWD_TRAINING:-0} + +if [ $USE_SOFTMAX -eq 1 ]; then + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training -p_drop=0.2" +else + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training -p_drop=0.2" +fi + +dtype="bf16" + +set -x + +target8="10,10,14,17,16,12,14,9" + +## seqlen 1024 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=4076 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +target16="13,17,16,13,7,14,3,18,15,15,1,9,18,18,7,10" + +## seqlen 
1024 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=4076 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +target32="13,17,16,13,7,14,3,18,15,15,1,9,18,18,7,10,11,0,4,8,2,10,20,14,11,7,4,6,9,7,14,17" + +## seqlen 1024 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype 
-b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=4076 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +set +x + diff --git a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256_with_dropout.sh new file mode 100644 index 0000000000..979a3b0af9 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256_with_dropout.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +set +x + +ndist=0 +BUILD=build + +USE_SOFTMAX=0 +if [ $# -ge 1 ]; then + USE_SOFTMAX=$1 +fi + +Training=${TEST_HSTU_FWD_TRAINING:-0} + +if [ $USE_SOFTMAX -eq 1 ]; then + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training -p_drop=0.2" +else + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training -p_drop=0.2" +fi + +dtype="bf16" + +set -x + +target8="10,10,14,17,16,12,14,9" + +## seqlen 1024 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype 
-b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=4076 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=8 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target8 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +target16="13,17,16,13,7,14,3,18,15,15,1,9,18,18,7,10" + +## seqlen 1024 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128
-hdim_v=128 -seqlens=4076 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=16 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target16 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +target32="13,17,16,13,7,14,3,18,15,15,1,9,18,18,7,10,11,0,4,8,2,10,20,14,11,7,4,6,9,7,14,17" + +## seqlen 1024 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=1004 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 2048 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=2028 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 3072 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=3052 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 4096 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=4076 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 8192 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=8172 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +## seqlen 16384 +$EXE -v=1 -prec=$dtype -b=32 -jagged=1 -nhead=4 
-hdim_qk=128 -hdim_v=128 -seqlens=16364 -causal=1 -local_len=256 -context_len=0 -minfull_len=256 -targets=$target32 -max_target=20 -alpha=2.0 -norm_dist=$ndist +echo -e "" + +set +x +