mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add missing function and parameters (#1493)
[ROCm/composable_kernel commit: 8107ee6270]
This commit is contained in:
@@ -39,7 +39,8 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
std::vector<int32_t> generate_seqlens(unsigned count,
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min = -1, // if not negative, clamp min
|
||||
int32_t seqlen_max = -1, // if not negative, clamp max
|
||||
@@ -53,7 +54,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
|
||||
if(1 < count)
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
@@ -67,7 +68,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is always greater than seqlen_min
|
||||
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
|
||||
if(seqlens[to_decrease] == seqlen_min)
|
||||
{
|
||||
continue;
|
||||
@@ -88,6 +89,16 @@ std::vector<int32_t> generate_seqlens(unsigned count,
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
std::vector<int32_t> generate_seqstarts(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min = -1,
|
||||
int32_t seqlen_max = -1,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
|
||||
}
|
||||
|
||||
// return random integer generated uniformly in range [low, high]
|
||||
template <typename Int = int>
|
||||
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
|
||||
@@ -220,9 +231,9 @@ decode_seqlen(mode_enum mode,
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q = generate_seqlens(batch - idx, s_q.back(), 1, s_kpad.back(), seed);
|
||||
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
|
||||
auto rem_k =
|
||||
generate_seqlens(batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
|
||||
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
|
||||
Reference in New Issue
Block a user