Add example parameter max_seqlen and max_target

This commit is contained in:
Qianfeng Zhang
2025-05-27 14:18:41 +00:00
parent c9e19351c7
commit 10c35125d2
8 changed files with 49 additions and 34 deletions

View File

@@ -44,8 +44,10 @@
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")

View File

@@ -12,9 +12,9 @@ for seqlen in 1024 2048 4096 8192 16384 32768; do
set -x
## no causal
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$seqlen -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$seqlen -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
## has causal
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$seqlen -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$seqlen -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
set +x
done

View File

@@ -40,29 +40,29 @@ s_sl16384=`add_comma $sl16384`
set -x
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
set +x

View File

@@ -43,17 +43,17 @@ s_sl32768=`add_comma $sl32768`
set -x
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
echo -e ""
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl32768 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl32768 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
set +x

View File

@@ -8,10 +8,10 @@ for dtype in "fp16" "bf16"; do
set -x
## jagged is true
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlens=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
## jagged is false
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlens=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
set +x
done

View File

@@ -77,8 +77,10 @@ auto create_args(int argc, char* argv[])
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
@@ -203,9 +205,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
std::string str_of_lengths = arg_parser.get_str("seqlen");
std::string str_of_lengths = arg_parser.get_str("seqlens");
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
int input_max_target = arg_parser.get_int("max_target");
int uih_seqlen = 0; // means total seq lengths for jagged
int max_uih_seqlen = 0;
int max_target = 0;
@@ -258,6 +263,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert(uih_seqlen >= min_full_attn_seqlen);
};
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
// the user input of max_target can either be ignored or be bigger than all targets
assert(input_max_uih_seqlen <= 0 || input_max_uih_seqlen >= max_uih_seqlen);
assert(input_max_target <= 0 || input_max_target >= max_target);
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
max_target = (input_max_target > 0) ? input_max_target : max_target;
int phy_seqlen = 0;
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;

View File

@@ -3,4 +3,4 @@
BUILD=build
EXE=$BUILD/bin/tile_example_hstu_attention
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlen=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1

View File

@@ -7,34 +7,34 @@ for dtype in "fp16" "bf16"; do
set -x
## no masking batched
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
## no masking jagged
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
## batched causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
## jagged causal
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
## batched causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
## jagged causal+local
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
## batched causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
## jagged causal+local+context
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
## batched causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
## jagged causal+local+context+target
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
set +x
done