diff --git a/model.py b/model.py
index 4650e7f..756908c 100644
--- a/model.py
+++ b/model.py
@@ -336,6 +336,13 @@ class ModelContainer:
             add_bos=kwargs.get("add_bos_token", True),
             encode_special_tokens = True
         )
+        context_len = len(ids[0])
+
+        if context_len > self.config.max_seq_len:
+            print(
+                f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.",
+                "Generation is truncated and metrics may not be accurate."
+            )
 
         # Begin
 
@@ -392,14 +399,18 @@ class ModelContainer:
             elapsed_time = last_chunk_time - start_time
 
             initial_response = f"Response: {round(generated_tokens, 2)} tokens generated in {round(elapsed_time, 2)} seconds"
-            extra_responses = []
+            itemization = []
+            extra_parts = []
 
             # Add tokens per second
-            extra_responses.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
+            itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
 
             # Add context (original token count)
             if ids is not None:
-                extra_responses.append(f"context {len(ids[0])} tokens")
+                itemization.append(f"context {context_len} tokens")
+
+                if context_len > self.config.max_seq_len:
+                    extra_parts.append("<-- Not accurate (truncated)")
 
             # Print output
-            print(initial_response + " (" + ", ".join(extra_responses) + ")")
+            print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))