Merge remote-tracking branch 'upstream/main' into HEAD

This commit is contained in:
TerminalMan
2024-09-11 15:57:18 +01:00
28 changed files with 386 additions and 171 deletions

View File

@@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import aiofiles
import secrets
import yaml
from fastapi import Header, HTTPException, Request
@@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
def load_auth_keys(disable_from_config: bool):
async def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
@@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
return
try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.safe_load(contents)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
@@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
)
AUTH_KEYS = new_auth_keys
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
new_auth_yaml = yaml.safe_dump(
AUTH_KEYS.model_dump(), default_flow_style=False
)
await auth_file.write(new_auth_yaml)
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"

View File

@@ -13,7 +13,6 @@ from typing import Optional
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.utils import unwrap
from endpoints.utils import do_export_openapi
if not do_export_openapi:
@@ -67,7 +66,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.")
await unload_model()
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
# Merge with config defaults
kwargs = {**config.model_defaults, **kwargs}
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress, **kwargs)
@@ -149,25 +152,6 @@ async def unload_embedding_model():
embeddings_container = None
# FIXME: Maybe make this a one-time function instead of a dynamic default
def get_config_default(key: str, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""
default_keys = unwrap(config.model.use_as_default, [])
# Add extra keys to defaults
default_keys.append("embeddings_device")
if key in default_keys:
# Is this a draft model load parameter?
if model_type == "draft":
return config.draft_model.get(key)
elif model_type == "embedding":
return config.embeddings.get(key)
else:
return config.model.get(key)
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""

View File

@@ -1,5 +1,7 @@
"""Common functions for sampling parameters"""
import aiofiles
import json
import pathlib
import yaml
from copy import deepcopy
@@ -140,6 +142,28 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
dry_multiplier: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
)
dry_base: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
)
dry_allowed_length: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
)
dry_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_range", 0),
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
description=("Aliases: dry_penalty_last_n"),
)
dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)
mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
@@ -305,6 +329,17 @@ class BaseSamplerRequest(BaseModel):
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]
# Convert sequence breakers into an array of strings
# NOTE: This sampler sucks to parse.
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
if not self.dry_sequence_breakers.startswith("["):
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
try:
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
except Exception:
self.dry_sequence_breakers = []
gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
@@ -335,6 +370,11 @@ class BaseSamplerRequest(BaseModel):
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"dry_multiplier": self.dry_multiplier,
"dry_base": self.dry_base,
"dry_allowed_length": self.dry_allowed_length,
"dry_sequence_breakers": self.dry_sequence_breakers,
"dry_range": self.dry_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
@@ -368,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
raise TypeError("New sampler overrides must be a dict!")
def overrides_from_file(preset_name: str):
async def overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
contents = await raw_preset.read()
preset = yaml.safe_load(contents)
overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")

View File

@@ -10,8 +10,13 @@ import common.config_models
class TabbyConfig(tabby_config_model):
def load_config(self, arguments: Optional[dict] = None):
"""load the global application config"""
# Persistent defaults
# TODO: make this pydantic?
model_defaults: dict = {}
def load(self, arguments: Optional[dict] = None):
"""Synchronously loads the global application config"""
# config is applied in order of items in the list
configs = [
@@ -28,6 +33,17 @@ class TabbyConfig(tabby_config_model):
setattr(self, field, model.parse_obj(value))
# Set model defaults dict once to prevent on-demand reconstruction
# TODO: clean this up a bit
for field in self.model.use_as_default:
if hasattr(self.model, field):
self.model_defaults[field] = getattr(config.model, field)
elif hasattr(self.draft_model, field):
self.model_defaults[field] = getattr(config.draft_model, field)
else:
# TODO: show an error
pass
def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path"""
@@ -53,7 +69,7 @@ class TabbyConfig(tabby_config_model):
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Config file override detected in args.")
config = self.from_file(pathlib.Path(config_override))
config = self._from_file(pathlib.Path(config_override))
return config # Return early if loading from file
for key in tabby_config_model.model_fields.keys():
@@ -85,5 +101,5 @@ class TabbyConfig(tabby_config_model):
return config
# Create an empty instance of the shared var to make sure nothing breaks
# Create an empty instance of the config class
config: TabbyConfig = TabbyConfig()

View File

@@ -1,10 +1,12 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import aiofiles
import json
import pathlib
from importlib.metadata import version as package_version
from typing import List, Optional
from jinja2 import Template, TemplateError
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
@@ -32,7 +34,10 @@ class PromptTemplate:
raw_template: str
template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, enable_async=True
trim_blocks=True,
lstrip_blocks=True,
enable_async=True,
extensions=[loopcontrols],
)
metadata: Optional[TemplateMetadata] = None
@@ -106,32 +111,42 @@ class PromptTemplate:
self.template = self.compile(raw_template)
@classmethod
def from_file(self, prompt_template_name: str):
async def from_file(cls, template_path: pathlib.Path):
"""Get a template from a jinja file."""
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
# Add the jinja extension if it isn't provided
if template_path.suffix.endswith(".jinja"):
template_name = template_path.name.split(".jinja")[0]
else:
template_name = template_path.name
template_path = template_path.with_suffix(".jinja")
if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template_stream:
return PromptTemplate(
name=prompt_template_name,
raw_template=raw_template_stream.read(),
async with aiofiles.open(
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return cls(
name=template_name,
raw_template=contents,
)
else:
# Let the user know if the template file isn't found
raise TemplateLoadError(
f'Chat template "{prompt_template_name}" not found in files.'
f'Chat template "{template_name}" not found in files.'
)
@classmethod
def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
async def from_model_json(
cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
contents = await config_file.read()
model_config = json.loads(contents)
chat_template = model_config.get(key)
if not chat_template:
@@ -162,7 +177,7 @@ class PromptTemplate:
)
else:
# Can safely assume the chat template is the old style
return PromptTemplate(
return cls(
name="from_tokenizer_config",
raw_template=chat_template,
)

View File

@@ -1,3 +1,4 @@
import aiofiles
import json
import pathlib
from typing import List, Optional, Union
@@ -15,15 +16,16 @@ class GenerationConfig(BaseModel):
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
generation_config_path = model_directory / "generation_config.json"
with open(
async with aiofiles.open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
return self.model_validate(generation_config_dict)
contents = await generation_config_json.read()
generation_config_dict = json.loads(contents)
return cls.model_validate(generation_config_dict)
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
@@ -43,13 +45,16 @@ class HuggingFaceConfig(BaseModel):
badwordsids: Optional[str] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
hf_config_dict = json.load(hf_config_json)
return self.model_validate(hf_config_dict)
async with aiofiles.open(
hf_config_path, "r", encoding="utf8"
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return cls.model_validate(hf_config_dict)
def get_badwordsids(self):
"""Wrapper method to fetch badwordsids."""