Model: Warn user if context > max_seq_len

Unlike other backends, tabby attempts to generate even when the context
is longer than the max sequence length, by truncating the given
context to fit.

Rather than artificially erroring out, print a warning that the
console metrics will be inaccurate and advise the user to keep
context <= max_seq_len.
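
For illustration, a minimal sketch of the truncation behavior described
above, assuming left truncation that keeps the most recent tokens (the
names and the 4096 value are illustrative, not tabby's actual internals):

    # Illustrative only -- assumes left truncation; not the actual backend code.
    # ids is the (1, seq_len) tensor of prompt token ids from tokenizer.encode.
    max_seq_len = 4096  # assumed model limit
    if ids.shape[-1] > max_seq_len:
        # Keep only the most recent max_seq_len tokens; earlier context
        # is silently dropped before generation.
        ids = ids[:, -max_seq_len:]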

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2023-11-29 01:35:32 -05:00
Parent: cad144126f
Commit: 94696543bc

@@ -336,6 +336,13 @@ class ModelContainer:
             add_bos=kwargs.get("add_bos_token", True),
             encode_special_tokens = True
         )
+        context_len = len(ids[0])
+
+        if context_len > self.config.max_seq_len:
+            print(
+                f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.",
+                "Generation is truncated and metrics may not be accurate."
+            )
 
         # Begin
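
With an illustrative 5000-token prompt against a 4096-token max_seq_len,
print joins its two arguments with a space, so the console shows:

    WARNING: The context length 5000 is greater than the max_seq_len 4096. Generation is truncated and metrics may not be accurate.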
@@ -392,14 +399,18 @@ class ModelContainer:
         elapsed_time = last_chunk_time - start_time
 
         initial_response = f"Response: {round(generated_tokens, 2)} tokens generated in {round(elapsed_time, 2)} seconds"
-        extra_responses = []
+        itemization = []
+        extra_parts = []
 
         # Add tokens per second
-        extra_responses.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
+        itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
 
         # Add context (original token count)
         if ids is not None:
-            extra_responses.append(f"context {len(ids[0])} tokens")
+            itemization.append(f"context {context_len} tokens")
+
+        if context_len > self.config.max_seq_len:
+            extra_parts.append("<-- Not accurate (truncated)")
 
         # Print output
-        print(initial_response + " (" + ", ".join(extra_responses) + ")")
+        print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))
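
Using the same illustrative numbers (200 tokens generated in 10.5
seconds over the truncated 5000-token prompt), the rewritten print
yields:

    Response: 200 tokens generated in 10.5 seconds (19.05 T/s, context 5000 tokens) <-- Not accurate (truncated)

When the context fits, extra_parts stays empty, so the marker is
omitted and the line ends after the closing parenthesis (with a
harmless trailing space).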