Fix wrong boundaries

This commit is contained in:
PoYen, Chen
2024-07-15 01:42:53 +00:00
parent 4e01307e04
commit 65dac9fb90

View File

@@ -47,8 +47,8 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
{
assert(0 < count);
seqlen_min = (0 < seqlen_min ? seqlen_min : seqlen_avg);
seqlen_max = (0 < seqlen_max ? seqlen_max : seqlen_avg);
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
@@ -116,16 +116,20 @@ decode_seqlen(mode_enum mode,
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
if(k < seqlen_k_min)
{
std::ostringstream msg;
msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min
<< ")";
throw std::runtime_error(msg.str());
}
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
else
@@ -150,17 +154,19 @@ decode_seqlen(mode_enum mode,
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
if(k < seqlen_k_min)
{
std::ostringstream msg;
msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min
<< ")";
throw std::runtime_error(msg.str());
}
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++;
if(found_q == std::string::npos || idx >= batch)
{