mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix wrong boundaries
This commit is contained in:
@@ -47,8 +47,8 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : seqlen_avg);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : seqlen_avg);
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
|
||||
assert(seqlen_min <= seqlen_max);
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
@@ -116,16 +116,20 @@ decode_seqlen(mode_enum mode,
|
||||
{
|
||||
ck_tile::index_t q = _S2I_(q_val);
|
||||
ck_tile::index_t k = _S2I_(k_val);
|
||||
if(k < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
else
|
||||
@@ -150,17 +154,19 @@ decode_seqlen(mode_enum mode,
|
||||
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
|
||||
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
|
||||
|
||||
if(k < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
idx++;
|
||||
if(found_q == std::string::npos || idx >= batch)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user