From 94696543bc92c0f8c6c69ea5530c50223fb33322 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 29 Nov 2023 01:35:32 -0500 Subject: [PATCH] Model: Warn user if context > max_seq_len Unlike other backends, tabby attempts to generate even if the context is greater than the max sequence length via truncation of the given context. Rather than artificially erroring out, give a warning that outputted console metrics are going to be incorrect and to make sure that context <= max_seq_len. Signed-off-by: kingbri --- model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/model.py b/model.py index 4650e7f..756908c 100644 --- a/model.py +++ b/model.py @@ -336,6 +336,13 @@ class ModelContainer: add_bos=kwargs.get("add_bos_token", True), encode_special_tokens = True ) + context_len = len(ids[0]) + + if context_len > self.config.max_seq_len: + print( + f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.", + "Generation is truncated and metrics may not be accurate."
+ ) # Begin @@ -392,14 +399,18 @@ class ModelContainer: elapsed_time = last_chunk_time - start_time initial_response = f"Response: {round(generated_tokens, 2)} tokens generated in {round(elapsed_time, 2)} seconds" - extra_responses = [] + itemization = [] + extra_parts = [] # Add tokens per second - extra_responses.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s") + itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s") # Add context (original token count) if ids is not None: - extra_responses.append(f"context {len(ids[0])} tokens") + itemization.append(f"context {context_len} tokens") + + if context_len > self.config.max_seq_len: + extra_parts.append("<-- Not accurate (truncated)") # Print output - print(initial_response + " (" + ", ".join(extra_responses) + ")") + print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))