diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 70a5844cde..996032a717 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -39,7 +39,8 @@ std::vector to_seqstarts(ck_tile::span seqlens) return seqstarts; } -std::vector generate_seqlens(unsigned count, +std::vector generate_seqlens(mode_enum mode, + unsigned count, int32_t seqlen_avg, int32_t seqlen_min = -1, // if not negative, clamp min int32_t seqlen_max = -1, // if not negative, clamp max @@ -53,7 +54,7 @@ std::vector generate_seqlens(unsigned count, std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); - if(1 < count) + if(mode == mode_enum::group && 1 < count) { using size_type = std::vector::size_type; @@ -67,7 +68,7 @@ std::vector generate_seqlens(unsigned count, for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); - // make sure each elements of seqlens is always greater than seqlen_min + // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] if(seqlens[to_decrease] == seqlen_min) { continue; @@ -88,6 +89,16 @@ std::vector generate_seqlens(unsigned count, return seqlens; } +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min = -1, + int32_t seqlen_max = -1, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed)); +} + // return random integer generated uniformly in range [low, high] template auto randint(Int low, Int high, std::optional seed = std::nullopt) @@ -220,9 +231,9 @@ decode_seqlen(mode_enum mode, } if(idx < batch) { - auto rem_q = generate_seqlens(batch - idx, s_q.back(), 1, s_kpad.back(), seed); + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed); auto rem_k = - generate_seqlens(batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); + generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());