Tree: Add generation logging support

Generations can be logged in the console along with sampling parameters
if the user enables it in config.

Metrics are always logged at the end of each prompt. In addition,
the model endpoint tells the user if they're being logged or not
for transparency purposes.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-12 23:43:35 -05:00
parent b364de1005
commit 083df7d585
5 changed files with 91 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional
from gen_logging import LogConfig
class ModelCardParameters(BaseModel):
max_seq_len: Optional[int] = 4096
@@ -14,6 +15,7 @@ class ModelCard(BaseModel):
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[LogConfig] = None
parameters: Optional[ModelCardParameters] = None
class ModelList(BaseModel):

View File

@@ -13,6 +13,14 @@ network:
# The port to host on (default: 5000)
port: 5000
# Options for logging
logging:
# Enable prompt logging (default: False)
prompt: False
# Enable generation parameter logging (default: False)
generation_params: False
# Options for model overrides and loading
model:
# Overrides the directory to look for models (default: models)

47
gen_logging.py Normal file
View File

@@ -0,0 +1,47 @@
from typing import Dict
from pydantic import BaseModel
# Logging preference config
class LogConfig(BaseModel):
prompt: bool = False
generation_params: bool = False
# Global reference to logging preferences
config = LogConfig()
# Wrapper to set the logging config for generations
def update_from_dict(options_dict: Dict[str, bool]):
global config
# Force bools on the dict
for value in options_dict.values():
if value is None:
value = False
config = LogConfig.parse_obj(options_dict)
def broadcast_status():
enabled = []
if config.prompt:
enabled.append("prompts")
if config.generation_params:
enabled.append("generation params")
if len(enabled) > 0:
print("Generation logging is enabled for: " + ", ".join(enabled))
else:
print("Generation logging is disabled")
# Logs generation parameters to console
def log_generation_params(**kwargs):
if config.generation_params:
print(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str):
if config.prompt:
print(f"Prompt: {prompt if prompt else 'Empty'}\n")
def log_response(response: str):
if config.prompt:
print(f"Response: {response if response else 'Empty'}\n")

11
main.py
View File

@@ -1,6 +1,7 @@
import uvicorn
import yaml
import pathlib
import gen_logging
from asyncio import CancelledError
from auth import check_admin_key, check_api_key, load_auth_keys
from fastapi import FastAPI, Request, HTTPException, Depends
@@ -81,7 +82,8 @@ async def get_current_model():
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
prompt_template = unwrap(model_container.prompt_template, "auto")
)
),
logging = gen_logging.config
)
if model_container.draft_config:
@@ -370,6 +372,13 @@ if __name__ == "__main__":
)
config = {}
# Override the generation log options if given
log_config = unwrap(config.get("logging"), {})
if log_config:
gen_logging.update_from_dict(log_config)
gen_logging.broadcast_status()
# If an initial model name is specified, create a container and load the model
model_config = unwrap(config.get("model"), {})
if "model_name" in model_config:

View File

@@ -14,6 +14,7 @@ from exllamav2.generator import(
)
from typing import List, Optional, Union
from utils import coalesce, unwrap
from gen_logging import log_generation_params, log_prompt, log_response
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
@@ -351,13 +352,6 @@ class ModelContainer:
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
# Ban the EOS token if specified. If not, append to stop conditions as well.
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
else:
stop_conditions.append(self.tokenizer.eos_token_id)
# Override sampler settings for temp = 0
if gen_settings.temperature == 0:
gen_settings.temperature = 1.0
@@ -365,7 +359,25 @@ class ModelContainer:
gen_settings.top_p = 0
gen_settings.typical = 0
# Stop conditions
# Log generation options to console
log_generation_params(
**vars(gen_settings),
token_healing = token_healing,
max_tokens = max_tokens,
stop_conditions = stop_conditions
)
# Log prompt to console
log_prompt(prompt)
# Ban the EOS token if specified. If not, append to stop conditions as well.
# Set this below logging to avoid polluting the stop strings array
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
else:
stop_conditions.append(self.tokenizer.eos_token_id)
# Stop conditions
self.generator.set_stop_conditions(stop_conditions)
# Tokenized context
@@ -430,9 +442,12 @@ class ModelContainer:
if eos or generated_tokens == max_tokens: break
# Print response
log_response(full_response)
elapsed_time = last_chunk_time - start_time
initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
itemization = []
extra_parts = []