diff --git a/OAI/types/model.py b/OAI/types/model.py
index 4f636e4..15cb760 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -1,6 +1,7 @@
 from pydantic import BaseModel, Field
 from time import time
 from typing import List, Optional
+from gen_logging import LogConfig
 
 class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = 4096
@@ -14,6 +15,7 @@ class ModelCard(BaseModel):
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time()))
     owned_by: str = "tabbyAPI"
+    logging: Optional[LogConfig] = None
     parameters: Optional[ModelCardParameters] = None
 
 class ModelList(BaseModel):
diff --git a/config_sample.yml b/config_sample.yml
index f69db6e..0e4ddd4 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -13,6 +13,14 @@ network:
   # The port to host on (default: 5000)
   port: 5000
 
+# Options for logging
+logging:
+  # Enable prompt logging (default: False)
+  prompt: False
+
+  # Enable generation parameter logging (default: False)
+  generation_params: False
+
 # Options for model overrides and loading
 model:
   # Overrides the directory to look for models (default: models)
diff --git a/gen_logging.py b/gen_logging.py
new file mode 100644
index 0000000..ff18ced
--- /dev/null
+++ b/gen_logging.py
@@ -0,0 +1,47 @@
+from typing import Dict
+from pydantic import BaseModel
+
+# Logging preference config
+class LogConfig(BaseModel):
+    prompt: bool = False
+    generation_params: bool = False
+
+# Global reference to logging preferences
+config = LogConfig()
+
+# Wrapper to set the logging config for generations
+def update_from_dict(options_dict: Dict[str, bool]):
+    global config
+
+    # Coerce None values to False so LogConfig validation doesn't fail
+    for key, value in options_dict.items():
+        if value is None:
+            options_dict[key] = False
+
+    config = LogConfig.parse_obj(options_dict)
+
+def broadcast_status():
+    enabled = []
+    if config.prompt:
+        enabled.append("prompts")
+
+    if config.generation_params:
+        enabled.append("generation params")
+
+    if len(enabled) > 0:
+        print("Generation logging is enabled for: " + ", ".join(enabled))
+    else:
+        print("Generation logging is disabled")
+
+# Logs generation parameters to console
+def log_generation_params(**kwargs):
+    if config.generation_params:
+        print(f"Generation options: {kwargs}\n")
+
+def log_prompt(prompt: str):
+    if config.prompt:
+        print(f"Prompt: {prompt if prompt else 'Empty'}\n")
+
+def log_response(response: str):
+    if config.prompt:
+        print(f"Response: {response if response else 'Empty'}\n")
diff --git a/main.py b/main.py
index 22455c1..9f4871d 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
 import uvicorn
 import yaml
 import pathlib
+import gen_logging
 from asyncio import CancelledError
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
@@ -81,7 +82,8 @@ async def get_current_model():
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
             prompt_template = unwrap(model_container.prompt_template, "auto")
-        )
+        ),
+        logging = gen_logging.config
     )
 
     if model_container.draft_config:
@@ -370,6 +372,13 @@ if __name__ == "__main__":
         )
         config = {}
 
+    # Override the generation log options if given
+    log_config = unwrap(config.get("logging"), {})
+    if log_config:
+        gen_logging.update_from_dict(log_config)
+
+    gen_logging.broadcast_status()
+
     # If an initial model name is specified, create a container and load the model
     model_config = unwrap(config.get("model"), {})
     if "model_name" in model_config:
diff --git a/model.py b/model.py
index d281dba..f34501d 100644
--- a/model.py
+++ b/model.py
@@ -14,6 +14,7 @@ from exllamav2.generator import(
 )
 from typing import List, Optional, Union
 from utils import coalesce, unwrap
+from gen_logging import log_generation_params, log_prompt, log_response
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -351,13 +352,6 @@ class ModelContainer:
 
         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
-
-        # Ban the EOS token if specified. If not, append to stop conditions as well.
-        if ban_eos_token:
-            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
-        else:
-            stop_conditions.append(self.tokenizer.eos_token_id)
-
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
             gen_settings.temperature = 1.0
@@ -365,7 +359,25 @@
             gen_settings.top_p = 0
             gen_settings.typical = 0
 
-        # Stop conditions
+        # Log generation options to console
+        log_generation_params(
+            **vars(gen_settings),
+            token_healing = token_healing,
+            max_tokens = max_tokens,
+            stop_conditions = stop_conditions
+        )
+
+        # Log prompt to console
+        log_prompt(prompt)
+
+        # Ban the EOS token if specified. If not, append to stop conditions as well.
+        # Set this below logging to avoid polluting the stop strings array
+        if ban_eos_token:
+            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+        else:
+            stop_conditions.append(self.tokenizer.eos_token_id)
+
+        # Stop conditions
         self.generator.set_stop_conditions(stop_conditions)
 
         # Tokenized context
@@ -430,9 +442,12 @@
             if eos or generated_tokens == max_tokens:
                 break
 
+        # Print response
+        log_response(full_response)
+
         elapsed_time = last_chunk_time - start_time
 
-        initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
+        initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
         itemization = []
         extra_parts = []