From 083df7d5858c9c3768830c09b879d9f07091466a Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 12 Dec 2023 23:43:35 -0500 Subject: [PATCH] Tree: Add generation logging support Generations can be logged in the console along with sampling parameters if the user enables it in config. Metrics are always logged at the end of each prompt. In addition, the model endpoint tells the user if they're being logged or not for transparency purposes. Signed-off-by: kingbri --- OAI/types/model.py | 2 ++ config_sample.yml | 8 ++++++++ gen_logging.py | 47 ++++++++++++++++++++++++++++++++++++++++++++++ main.py | 11 ++++++++++- model.py | 33 +++++++++++++++++++++++--------- 5 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 gen_logging.py diff --git a/OAI/types/model.py b/OAI/types/model.py index 4f636e4..15cb760 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field from time import time from typing import List, Optional +from gen_logging import LogConfig class ModelCardParameters(BaseModel): max_seq_len: Optional[int] = 4096 @@ -14,6 +15,7 @@ class ModelCard(BaseModel): object: str = "model" created: int = Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" + logging: Optional[LogConfig] = None parameters: Optional[ModelCardParameters] = None class ModelList(BaseModel): diff --git a/config_sample.yml b/config_sample.yml index f69db6e..0e4ddd4 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -13,6 +13,14 @@ network: # The port to host on (default: 5000) port: 5000 +# Options for logging +logging: + # Enable prompt logging (default: False) + prompt: False + + # Enable generation parameter logging (default: False) + generation_params: False + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/gen_logging.py b/gen_logging.py new file mode 100644 index 0000000..ff18ced --- /dev/null +++ b/gen_logging.py @@ -0,0 
+1,47 @@ +from typing import Dict +from pydantic import BaseModel + +# Logging preference config +class LogConfig(BaseModel): + prompt: bool = False + generation_params: bool = False + +# Global reference to logging preferences +config = LogConfig() + +# Wrapper to set the logging config for generations +def update_from_dict(options_dict: Dict[str, bool]): + global config + + # Force bools on the dict + for value in options_dict.values(): + if value is None: + value = False + + config = LogConfig.parse_obj(options_dict) + +def broadcast_status(): + enabled = [] + if config.prompt: + enabled.append("prompts") + + if config.generation_params: + enabled.append("generation params") + + if len(enabled) > 0: + print("Generation logging is enabled for: " + ", ".join(enabled)) + else: + print("Generation logging is disabled") + +# Logs generation parameters to console +def log_generation_params(**kwargs): + if config.generation_params: + print(f"Generation options: {kwargs}\n") + +def log_prompt(prompt: str): + if config.prompt: + print(f"Prompt: {prompt if prompt else 'Empty'}\n") + +def log_response(response: str): + if config.prompt: + print(f"Response: {response if response else 'Empty'}\n") diff --git a/main.py b/main.py index 22455c1..9f4871d 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import uvicorn import yaml import pathlib +import gen_logging from asyncio import CancelledError from auth import check_admin_key, check_api_key, load_auth_keys from fastapi import FastAPI, Request, HTTPException, Depends @@ -81,7 +82,8 @@ async def get_current_model(): rope_alpha = model_container.config.scale_alpha_value, max_seq_len = model_container.config.max_seq_len, prompt_template = unwrap(model_container.prompt_template, "auto") - ) + ), + logging = gen_logging.config ) if model_container.draft_config: @@ -370,6 +372,13 @@ if __name__ == "__main__": ) config = {} + # Override the generation log options if given + log_config = unwrap(config.get("logging"), {}) + if 
log_config: + gen_logging.update_from_dict(log_config) + + gen_logging.broadcast_status() + # If an initial model name is specified, create a container and load the model model_config = unwrap(config.get("model"), {}) if "model_name" in model_config: diff --git a/model.py b/model.py index d281dba..f34501d 100644 --- a/model.py +++ b/model.py @@ -14,6 +14,7 @@ from exllamav2.generator import( ) from typing import List, Optional, Union from utils import coalesce, unwrap +from gen_logging import log_generation_params, log_prompt, log_response # Bytes to reserve on first device when loading with auto split auto_split_reserve_bytes = 96 * 1024**2 @@ -351,13 +352,6 @@ class ModelContainer: stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) - - # Ban the EOS token if specified. If not, append to stop conditions as well. - if ban_eos_token: - gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) - else: - stop_conditions.append(self.tokenizer.eos_token_id) - # Override sampler settings for temp = 0 if gen_settings.temperature == 0: gen_settings.temperature = 1.0 @@ -365,7 +359,25 @@ class ModelContainer: gen_settings.top_p = 0 gen_settings.typical = 0 - # Stop conditions + # Log generation options to console + log_generation_params( + **vars(gen_settings), + token_healing = token_healing, + max_tokens = max_tokens, + stop_conditions = stop_conditions + ) + + # Log prompt to console + log_prompt(prompt) + + # Ban the EOS token if specified. If not, append to stop conditions as well. 
+ # Set this below logging to avoid polluting the stop strings array + if ban_eos_token: + gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) + else: + stop_conditions.append(self.tokenizer.eos_token_id) + + # Stop conditions self.generator.set_stop_conditions(stop_conditions) # Tokenized context @@ -430,9 +442,12 @@ class ModelContainer: if eos or generated_tokens == max_tokens: break + # Print response + log_response(full_response) + elapsed_time = last_chunk_time - start_time - initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds" + initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds" itemization = [] extra_parts = []