Update to the method for calculating max_seqlen in the example

This commit is contained in:
Qianfeng Zhang
2025-05-27 10:36:43 +00:00
parent dc0977faad
commit c9e19351c7
4 changed files with 78 additions and 78 deletions

View File

@@ -44,7 +44,7 @@
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")

View File

@@ -10,17 +10,16 @@ num_batch=32
num_head=4
target=20
add_target()
add_comma()
{
x=$*
y=""
for len in $x; do
new_len=$(($len + $target));
if test -z $y; then
y="$new_len"
y="$len"
else
y="$y,$new_len"
y="$y,$len"
fi;
done
@@ -33,11 +32,11 @@ sl4096="1497 2516 3179 2891 190 3572 640 3025 464 1824 712 1519 2727 2621 1135 7
sl8192="4571 3202 270 1540 8169 3365 6055 7181 2942 4213 2717 3593 7748 4646 5502 4489 6525 2481 7397 2983 5667 1003 7926 3659 6129 6647 3758 6244 4175 2327 849 5261"
sl16384="6956 7177 338 13755 10382 13392 10150 15592 15929 5256 6825 3804 5197 13415 14099 12418 13772 13659 5998 3715 9862 9183 11826 12964 6041 6712 12846 475 4672 7690 12280 10175"
s_sl1024=`add_target $sl1024`
s_sl2048=`add_target $sl2048`
s_sl4096=`add_target $sl4096`
s_sl8192=`add_target $sl8192`
s_sl16384=`add_target $sl16384`
s_sl1024=`add_comma $sl1024`
s_sl2048=`add_comma $sl2048`
s_sl4096=`add_comma $sl4096`
s_sl8192=`add_comma $sl8192`
s_sl16384=`add_comma $sl16384`
set -x

View File

@@ -11,17 +11,16 @@ num_head=4
window_size=5
target=20
add_target()
add_comma()
{
x=$*
y=""
for len in $x; do
new_len=$(($len + $target));
if test -z $y; then
y="$new_len"
y="$len"
else
y="$y,$new_len"
y="$y,$len"
fi;
done
@@ -35,12 +34,12 @@ sl8192="4571 3202 270 1540 8169 3365 6055 7181 2942 4213 2717 3593 7748 4646 550
sl16384="6956 7177 338 13755 10382 13392 10150 15592 15929 5256 6825 3804 5197 13415 14099 12418 13772 13659 5998 3715 9862 9183 11826 12964 6041 6712 12846 475 4672 7690 12280 10175"
sl32768="28810 1574 80 24581 32298 19576 8028 25764 16544 14321 22771 7622 21090 27370 15921 5841 5458 23228 23619 17897 11996 31636 23183 20444 26332 7742 3418 9181 4750 18744 5201 2019"
s_sl1024=`add_target $sl1024`
s_sl2048=`add_target $sl2048`
s_sl4096=`add_target $sl4096`
s_sl8192=`add_target $sl8192`
s_sl16384=`add_target $sl16384`
s_sl32768=`add_target $sl32768`
s_sl1024=`add_comma $sl1024`
s_sl2048=`add_comma $sl2048`
s_sl4096=`add_comma $sl4096`
s_sl8192=`add_comma $sl8192`
s_sl16384=`add_comma $sl16384`
s_sl32768=`add_comma $sl32768`
set -x

View File

@@ -77,7 +77,7 @@ auto create_args(int argc, char* argv[])
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
.insert("seqlen", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
@@ -206,25 +206,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string str_of_lengths = arg_parser.get_str("seqlen");
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
std::vector<int> seq_offsets;
int uih_seqlen = 0; // means total seq lengths for jagged
int max_uih_seqlen = 0;
int max_target = 0;
int seqlen = 0; // means total seq lengths for jagged
int max_seqlen = 0;
// supplement the sequence of lengths to have num_batch values
if(!num_targets.empty() && static_cast<int>(num_targets.size()) < num_batch)
if(!num_targets.empty())
{
auto last_val = num_targets.back();
// supplement num_targets using the last input value if user-provided lengths not enough
if(static_cast<int>(num_targets.size()) < num_batch)
{
auto last_val = num_targets.back();
for(int i = num_targets.size(); i < num_batch; i++)
num_targets.push_back(last_val);
for(int i = num_targets.size(); i < num_batch; i++)
num_targets.push_back(last_val);
};
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
max_target = max(max_target, num_targets[i]);
};
if(is_jagged)
{
assert(num_batch >= seq_lengths.size());
// supplement the sequence of lengths to have num_batch values
// supplement seq_lengths using the last input value if user-provided lengths not enough
if(static_cast<int>(seq_lengths.size()) < num_batch)
{
auto last_len = seq_lengths.back();
@@ -233,54 +237,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
seq_lengths.push_back(last_len);
};
seq_offsets.push_back(0);
for(size_t i = 0; i < seq_lengths.size(); i++)
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
{
max_seqlen = max(max_seqlen, seq_lengths[i]);
seqlen += seq_lengths[i];
seq_offsets.push_back(seqlen);
max_uih_seqlen = max(max_uih_seqlen, seq_lengths[i]);
};
if(!num_targets.empty())
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
{
assert(num_batch == num_targets.size());
for(size_t i = 0; i < seq_lengths.size(); i++)
{
assert(seq_lengths[i] - num_targets[i] >= min_full_attn_seqlen);
assert(seq_lengths[i] - num_targets[i] >= contextual_seqlen);
};
}
else
{
for(size_t i = 0; i < seq_lengths.size(); i++)
{
assert(seq_lengths[i] >= min_full_attn_seqlen);
assert(seq_lengths[i] >= contextual_seqlen);
};
assert(seq_lengths[i] >= min_full_attn_seqlen);
};
}
else
{
assert(1 == seq_lengths.size());
seqlen = seq_lengths[0];
max_seqlen = seqlen;
uih_seqlen = seq_lengths[0];
max_uih_seqlen = uih_seqlen;
if(!num_targets.empty())
{
assert(num_batch == num_targets.size());
assert(uih_seqlen >= min_full_attn_seqlen);
};
for(size_t i = 0; i < seq_lengths.size(); i++)
{
assert(seqlen - num_targets[i] >= min_full_attn_seqlen);
assert(seqlen - num_targets[i] >= contextual_seqlen);
};
}
else
int phy_seqlen = 0;
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
std::vector<int> seq_offsets;
if(is_jagged)
{
seq_offsets.push_back(0);
for(int i = 0; i < num_batch; i++)
{
assert(seqlen >= min_full_attn_seqlen);
assert(seqlen >= contextual_seqlen);
int batch_seqlen = num_targets.empty()
? seq_lengths[i] + contextual_seqlen
: seq_lengths[i] + num_targets[i] + contextual_seqlen;
phy_seqlen += batch_seqlen;
seq_offsets.push_back(phy_seqlen);
};
}
else
{
phy_seqlen = max_seqlen;
};
long total_flops = 0;
@@ -288,31 +287,34 @@ bool run(const ck_tile::ArgParser& arg_parser)
// estimate the total flops occurred, ignoring the scaling and SILu
if(is_jagged)
{
for(auto len : seq_lengths)
for(int i = 0; i < num_batch; i++)
{
int len = seq_offsets[i + 1] - seq_offsets[i];
total_flops +=
(static_cast<long>(len) * len * hdim_qk + static_cast<long>(len) * hdim_v * len) *
2;
};
total_flops *= num_head;
}
else
{
total_flops = static_cast<long>(num_batch) * num_head *
(static_cast<long>(seqlen) * seqlen * hdim_qk +
static_cast<long>(seqlen) * hdim_v * seqlen) *
(static_cast<long>(phy_seqlen) * phy_seqlen * hdim_qk +
static_cast<long>(phy_seqlen) * hdim_v * phy_seqlen) *
2;
};
int batches_for_alloc = is_jagged ? 1 : num_batch;
ck_tile::HostTensor<InOutDataType> q_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_qk});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> k_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_qk});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> v_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
@@ -379,7 +381,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
params.is_jagged = false;
params.num_batch = num_batch;
params.seqlen = seqlen;
params.seqlen = max_seqlen;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
@@ -458,7 +460,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask_host,
num_batch,
1.0f / std::sqrt(params.hdim_qk),
is_jagged ? max_seqlen : seqlen,
max_seqlen,
seq_offsets,
num_targets,
window_size,
@@ -467,7 +469,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
});
ck_tile::HostTensor<InOutDataType> o_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen, num_head, hdim_v});
o_dev.FromDevice(o_host.data());