mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Tree: Add generation logging support
Generations can be logged in the console along with sampling parameters if the user enables it in config. Metrics are always logged at the end of each prompt. In addition, the model endpoint tells the user if they're being logged or not for transparency purposes. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from time import time
|
||||
from typing import List, Optional
|
||||
from gen_logging import LogConfig
|
||||
|
||||
class ModelCardParameters(BaseModel):
|
||||
max_seq_len: Optional[int] = 4096
|
||||
@@ -14,6 +15,7 @@ class ModelCard(BaseModel):
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
owned_by: str = "tabbyAPI"
|
||||
logging: Optional[LogConfig] = None
|
||||
parameters: Optional[ModelCardParameters] = None
|
||||
|
||||
class ModelList(BaseModel):
|
||||
|
||||
@@ -13,6 +13,14 @@ network:
|
||||
# The port to host on (default: 5000)
|
||||
port: 5000
|
||||
|
||||
# Options for logging
|
||||
logging:
|
||||
# Enable prompt logging (default: False)
|
||||
prompt: False
|
||||
|
||||
# Enable generation parameter logging (default: False)
|
||||
generation_params: False
|
||||
|
||||
# Options for model overrides and loading
|
||||
model:
|
||||
# Overrides the directory to look for models (default: models)
|
||||
|
||||
47
gen_logging.py
Normal file
47
gen_logging.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Logging preference config
|
||||
class LogConfig(BaseModel):
|
||||
prompt: bool = False
|
||||
generation_params: bool = False
|
||||
|
||||
# Global reference to logging preferences
|
||||
config = LogConfig()
|
||||
|
||||
# Wrapper to set the logging config for generations
|
||||
def update_from_dict(options_dict: Dict[str, bool]):
|
||||
global config
|
||||
|
||||
# Force bools on the dict
|
||||
for value in options_dict.values():
|
||||
if value is None:
|
||||
value = False
|
||||
|
||||
config = LogConfig.parse_obj(options_dict)
|
||||
|
||||
def broadcast_status():
|
||||
enabled = []
|
||||
if config.prompt:
|
||||
enabled.append("prompts")
|
||||
|
||||
if config.generation_params:
|
||||
enabled.append("generation params")
|
||||
|
||||
if len(enabled) > 0:
|
||||
print("Generation logging is enabled for: " + ", ".join(enabled))
|
||||
else:
|
||||
print("Generation logging is disabled")
|
||||
|
||||
# Logs generation parameters to console
|
||||
def log_generation_params(**kwargs):
|
||||
if config.generation_params:
|
||||
print(f"Generation options: {kwargs}\n")
|
||||
|
||||
def log_prompt(prompt: str):
|
||||
if config.prompt:
|
||||
print(f"Prompt: {prompt if prompt else 'Empty'}\n")
|
||||
|
||||
def log_response(response: str):
|
||||
if config.prompt:
|
||||
print(f"Response: {response if response else 'Empty'}\n")
|
||||
11
main.py
11
main.py
@@ -1,6 +1,7 @@
|
||||
import uvicorn
|
||||
import yaml
|
||||
import pathlib
|
||||
import gen_logging
|
||||
from asyncio import CancelledError
|
||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
||||
@@ -81,7 +82,8 @@ async def get_current_model():
|
||||
rope_alpha = model_container.config.scale_alpha_value,
|
||||
max_seq_len = model_container.config.max_seq_len,
|
||||
prompt_template = unwrap(model_container.prompt_template, "auto")
|
||||
)
|
||||
),
|
||||
logging = gen_logging.config
|
||||
)
|
||||
|
||||
if model_container.draft_config:
|
||||
@@ -370,6 +372,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
config = {}
|
||||
|
||||
# Override the generation log options if given
|
||||
log_config = unwrap(config.get("logging"), {})
|
||||
if log_config:
|
||||
gen_logging.update_from_dict(log_config)
|
||||
|
||||
gen_logging.broadcast_status()
|
||||
|
||||
# If an initial model name is specified, create a container and load the model
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_name" in model_config:
|
||||
|
||||
33
model.py
33
model.py
@@ -14,6 +14,7 @@ from exllamav2.generator import(
|
||||
)
|
||||
from typing import List, Optional, Union
|
||||
from utils import coalesce, unwrap
|
||||
from gen_logging import log_generation_params, log_prompt, log_response
|
||||
|
||||
# Bytes to reserve on first device when loading with auto split
|
||||
auto_split_reserve_bytes = 96 * 1024**2
|
||||
@@ -351,13 +352,6 @@ class ModelContainer:
|
||||
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
|
||||
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
|
||||
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||
if ban_eos_token:
|
||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||
else:
|
||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||
|
||||
# Override sampler settings for temp = 0
|
||||
if gen_settings.temperature == 0:
|
||||
gen_settings.temperature = 1.0
|
||||
@@ -365,7 +359,25 @@ class ModelContainer:
|
||||
gen_settings.top_p = 0
|
||||
gen_settings.typical = 0
|
||||
|
||||
# Stop conditions
|
||||
# Log generation options to console
|
||||
log_generation_params(
|
||||
**vars(gen_settings),
|
||||
token_healing = token_healing,
|
||||
max_tokens = max_tokens,
|
||||
stop_conditions = stop_conditions
|
||||
)
|
||||
|
||||
# Log prompt to console
|
||||
log_prompt(prompt)
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
if ban_eos_token:
|
||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||
else:
|
||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||
|
||||
# Stop conditions
|
||||
self.generator.set_stop_conditions(stop_conditions)
|
||||
|
||||
# Tokenized context
|
||||
@@ -430,9 +442,12 @@ class ModelContainer:
|
||||
|
||||
if eos or generated_tokens == max_tokens: break
|
||||
|
||||
# Print response
|
||||
log_response(full_response)
|
||||
|
||||
elapsed_time = last_chunk_time - start_time
|
||||
|
||||
initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
|
||||
initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
|
||||
itemization = []
|
||||
extra_parts = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user