From 65dac9fb9065d0dbd4904c10d8e264ae6d09be95 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 15 Jul 2024 01:42:53 +0000 Subject: [PATCH] Fix wrong boundaries --- example/ck_tile/01_fmha/utils.hpp | 40 ++++++++++++++++++------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index abe6840c67..4af4e6959e 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -47,8 +47,8 @@ std::vector generate_seqlens(mode_enum mode, { assert(0 < count); - seqlen_min = (0 < seqlen_min ? seqlen_min : seqlen_avg); - seqlen_max = (0 < seqlen_max ? seqlen_max : seqlen_avg); + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); assert(seqlen_min <= seqlen_max); std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); @@ -116,16 +116,20 @@ decode_seqlen(mode_enum mode, { ck_tile::index_t q = _S2I_(q_val); ck_tile::index_t k = _S2I_(k_val); - if(k < seqlen_k_min) - { - std::ostringstream msg; - msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min - << ")"; - throw std::runtime_error(msg.str()); - } + auto s_q = std::vector(batch, q); auto s_k = std::vector(batch, k < 0 ? q : k); auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + return std::make_tuple(s_q, s_k, s_kpad); } else @@ -150,17 +154,19 @@ decode_seqlen(mode_enum mode, ck_tile::index_t kp = _S2I_(k_pad_val.substr( pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); - if(k < seqlen_k_min) - { - std::ostringstream msg; - msg << "seqlen_k (=" << k << ") is less than minimum seqlen_k (=" << seqlen_k_min - << ")"; - throw std::runtime_error(msg.str()); - } - s_q.push_back(q); s_k.push_back(k < 0 ? q : k); s_kpad.push_back(kp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + idx++; if(found_q == std::string::npos || idx >= batch) {