mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Add example parameter max_seqlen and max_target
This commit is contained in:
@@ -44,8 +44,10 @@
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
|
||||
@@ -12,9 +12,9 @@ for seqlen in 1024 2048 4096 8192 16384 32768; do
|
||||
set -x
|
||||
|
||||
## no causal
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$seqlen -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$seqlen -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
|
||||
|
||||
## has causal
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$seqlen -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=0 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$seqlen -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -perf=1
|
||||
set +x
|
||||
done
|
||||
|
||||
@@ -40,29 +40,29 @@ s_sl16384=`add_comma $sl16384`
|
||||
|
||||
set -x
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
|
||||
set +x
|
||||
|
||||
@@ -43,17 +43,17 @@ s_sl32768=`add_comma $sl32768`
|
||||
|
||||
set -x
|
||||
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl1024 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl1024 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl2048 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl2048 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl4096 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl4096 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl8192 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl8192 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl16384 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl16384 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
echo -e ""
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlen=$s_sl32768 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=$num_batch -jagged=1 -nhead=$num_head -hdim_qk=$hdim -hdim_v=$hdim -seqlens=$s_sl32768 -causal=1 -local_len=$window_size -context_len=0 -minfull_len=0 -targets=$target -perf=1
|
||||
|
||||
set +x
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## jagged is true
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=1 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlens=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
## jagged is false
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlen=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
$EXE -v=0 -prec=$dtype -b=512 -jagged=0 -nhead=2 -hdim_qk=128 -hdim_v=128 -seqlens=$seqlen -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -perf=1
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
@@ -77,8 +77,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlen", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
@@ -203,9 +205,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string str_of_targets = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
|
||||
|
||||
std::string str_of_lengths = arg_parser.get_str("seqlen");
|
||||
std::string str_of_lengths = arg_parser.get_str("seqlens");
|
||||
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
|
||||
|
||||
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
|
||||
int input_max_target = arg_parser.get_int("max_target");
|
||||
|
||||
int uih_seqlen = 0; // means total seq lengths for jagged
|
||||
int max_uih_seqlen = 0;
|
||||
int max_target = 0;
|
||||
@@ -258,6 +263,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
assert(uih_seqlen >= min_full_attn_seqlen);
|
||||
};
|
||||
|
||||
// the user-supplied max_seqlen may be left unset (<= 0); otherwise it must be greater than or equal to every uih_seqlen
|
||||
// the user-supplied max_target may be left unset (<= 0); otherwise it must be greater than or equal to every target
|
||||
assert(input_max_uih_seqlen <= 0 || input_max_uih_seqlen >= max_uih_seqlen);
|
||||
assert(input_max_target <= 0 || input_max_target >= max_target);
|
||||
|
||||
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
|
||||
max_target = (input_max_target > 0) ? input_max_target : max_target;
|
||||
|
||||
int phy_seqlen = 0;
|
||||
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
|
||||
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlen=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
|
||||
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
|
||||
|
||||
@@ -7,34 +7,34 @@ for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0
|
||||
|
||||
## batched causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
|
||||
## jagged causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0
|
||||
|
||||
## batched causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8
|
||||
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user