mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Merge remote-tracking branch 'upstream/main' into HEAD
This commit is contained in:
@@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
|
||||
application, it should be fine.
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import secrets
|
||||
import yaml
|
||||
from fastapi import Header, HTTPException, Request
|
||||
@@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
|
||||
DISABLE_AUTH: bool = False
|
||||
|
||||
|
||||
def load_auth_keys(disable_from_config: bool):
|
||||
async def load_auth_keys(disable_from_config: bool):
|
||||
"""Load the authentication keys from api_tokens.yml. If the file does not
|
||||
exist, generate new keys and save them to api_tokens.yml."""
|
||||
global AUTH_KEYS
|
||||
@@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
|
||||
return
|
||||
|
||||
try:
|
||||
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
auth_keys_dict = yaml.safe_load(auth_file)
|
||||
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
contents = await auth_file.read()
|
||||
auth_keys_dict = yaml.safe_load(contents)
|
||||
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
|
||||
except FileNotFoundError:
|
||||
new_auth_keys = AuthKeys(
|
||||
@@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
|
||||
)
|
||||
AUTH_KEYS = new_auth_keys
|
||||
|
||||
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
|
||||
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
new_auth_yaml = yaml.safe_dump(
|
||||
AUTH_KEYS.model_dump(), default_flow_style=False
|
||||
)
|
||||
await auth_file.write(new_auth_yaml)
|
||||
|
||||
logger.info(
|
||||
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing import Optional
|
||||
from common.logger import get_loading_progress_bar
|
||||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.utils import do_export_openapi
|
||||
|
||||
if not do_export_openapi:
|
||||
@@ -67,7 +66,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
|
||||
# Merge with config defaults
|
||||
kwargs = {**config.model_defaults, **kwargs}
|
||||
|
||||
# Create a new container
|
||||
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
|
||||
|
||||
model_type = "draft" if container.draft_config else "model"
|
||||
load_status = container.load_gen(load_progress, **kwargs)
|
||||
@@ -149,25 +152,6 @@ async def unload_embedding_model():
|
||||
embeddings_container = None
|
||||
|
||||
|
||||
# FIXME: Maybe make this a one-time function instead of a dynamic default
|
||||
def get_config_default(key: str, model_type: str = "model"):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
default_keys = unwrap(config.model.use_as_default, [])
|
||||
|
||||
# Add extra keys to defaults
|
||||
default_keys.append("embeddings_device")
|
||||
|
||||
if key in default_keys:
|
||||
# Is this a draft model load parameter?
|
||||
if model_type == "draft":
|
||||
return config.draft_model.get(key)
|
||||
elif model_type == "embedding":
|
||||
return config.embeddings.get(key)
|
||||
else:
|
||||
return config.model.get(key)
|
||||
|
||||
|
||||
async def check_model_container():
|
||||
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Common functions for sampling parameters"""
|
||||
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
import yaml
|
||||
from copy import deepcopy
|
||||
@@ -140,6 +142,28 @@ class BaseSamplerRequest(BaseModel):
|
||||
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
|
||||
)
|
||||
|
||||
dry_multiplier: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
|
||||
)
|
||||
|
||||
dry_base: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
|
||||
)
|
||||
|
||||
dry_allowed_length: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
|
||||
)
|
||||
|
||||
dry_range: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_range", 0),
|
||||
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
|
||||
description=("Aliases: dry_penalty_last_n"),
|
||||
)
|
||||
|
||||
dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
|
||||
)
|
||||
|
||||
mirostat_mode: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
|
||||
)
|
||||
@@ -305,6 +329,17 @@ class BaseSamplerRequest(BaseModel):
|
||||
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
|
||||
]
|
||||
|
||||
# Convert sequence breakers into an array of strings
|
||||
# NOTE: This sampler sucks to parse.
|
||||
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
|
||||
if not self.dry_sequence_breakers.startswith("["):
|
||||
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
|
||||
|
||||
try:
|
||||
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
|
||||
except Exception:
|
||||
self.dry_sequence_breakers = []
|
||||
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"min_tokens": self.min_tokens,
|
||||
@@ -335,6 +370,11 @@ class BaseSamplerRequest(BaseModel):
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"penalty_range": self.penalty_range,
|
||||
"dry_multiplier": self.dry_multiplier,
|
||||
"dry_base": self.dry_base,
|
||||
"dry_allowed_length": self.dry_allowed_length,
|
||||
"dry_sequence_breakers": self.dry_sequence_breakers,
|
||||
"dry_range": self.dry_range,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
@@ -368,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
||||
def overrides_from_file(preset_name: str):
|
||||
async def overrides_from_file(preset_name: str):
|
||||
"""Fetches an override preset from a file"""
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
overrides_container.selected_preset = preset_path.stem
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
contents = await raw_preset.read()
|
||||
preset = yaml.safe_load(contents)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
|
||||
@@ -10,8 +10,13 @@ import common.config_models
|
||||
|
||||
|
||||
class TabbyConfig(tabby_config_model):
|
||||
def load_config(self, arguments: Optional[dict] = None):
|
||||
"""load the global application config"""
|
||||
|
||||
# Persistent defaults
|
||||
# TODO: make this pydantic?
|
||||
model_defaults: dict = {}
|
||||
|
||||
def load(self, arguments: Optional[dict] = None):
|
||||
"""Synchronously loads the global application config"""
|
||||
|
||||
# config is applied in order of items in the list
|
||||
configs = [
|
||||
@@ -28,6 +33,17 @@ class TabbyConfig(tabby_config_model):
|
||||
|
||||
setattr(self, field, model.parse_obj(value))
|
||||
|
||||
# Set model defaults dict once to prevent on-demand reconstruction
|
||||
# TODO: clean this up a bit
|
||||
for field in self.model.use_as_default:
|
||||
if hasattr(self.model, field):
|
||||
self.model_defaults[field] = getattr(config.model, field)
|
||||
elif hasattr(self.draft_model, field):
|
||||
self.model_defaults[field] = getattr(config.draft_model, field)
|
||||
else:
|
||||
# TODO: show an error
|
||||
pass
|
||||
|
||||
def _from_file(self, config_path: pathlib.Path):
|
||||
"""loads config from a given file path"""
|
||||
|
||||
@@ -53,7 +69,7 @@ class TabbyConfig(tabby_config_model):
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Config file override detected in args.")
|
||||
config = self.from_file(pathlib.Path(config_override))
|
||||
config = self._from_file(pathlib.Path(config_override))
|
||||
return config # Return early if loading from file
|
||||
|
||||
for key in tabby_config_model.model_fields.keys():
|
||||
@@ -85,5 +101,5 @@ class TabbyConfig(tabby_config_model):
|
||||
return config
|
||||
|
||||
|
||||
# Create an empty instance of the shared var to make sure nothing breaks
|
||||
# Create an empty instance of the config class
|
||||
config: TabbyConfig = TabbyConfig()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
||||
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from importlib.metadata import version as package_version
|
||||
from typing import List, Optional
|
||||
from jinja2 import Template, TemplateError
|
||||
from jinja2.ext import loopcontrols
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
@@ -32,7 +34,10 @@ class PromptTemplate:
|
||||
raw_template: str
|
||||
template: Template
|
||||
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True, lstrip_blocks=True, enable_async=True
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
enable_async=True,
|
||||
extensions=[loopcontrols],
|
||||
)
|
||||
metadata: Optional[TemplateMetadata] = None
|
||||
|
||||
@@ -106,32 +111,42 @@ class PromptTemplate:
|
||||
self.template = self.compile(raw_template)
|
||||
|
||||
@classmethod
|
||||
def from_file(self, prompt_template_name: str):
|
||||
async def from_file(cls, template_path: pathlib.Path):
|
||||
"""Get a template from a jinja file."""
|
||||
|
||||
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
|
||||
# Add the jinja extension if it isn't provided
|
||||
if template_path.suffix.endswith(".jinja"):
|
||||
template_name = template_path.name.split(".jinja")[0]
|
||||
else:
|
||||
template_name = template_path.name
|
||||
template_path = template_path.with_suffix(".jinja")
|
||||
|
||||
if template_path.exists():
|
||||
with open(template_path, "r", encoding="utf8") as raw_template_stream:
|
||||
return PromptTemplate(
|
||||
name=prompt_template_name,
|
||||
raw_template=raw_template_stream.read(),
|
||||
async with aiofiles.open(
|
||||
template_path, "r", encoding="utf8"
|
||||
) as raw_template_stream:
|
||||
contents = await raw_template_stream.read()
|
||||
return cls(
|
||||
name=template_name,
|
||||
raw_template=contents,
|
||||
)
|
||||
else:
|
||||
# Let the user know if the template file isn't found
|
||||
raise TemplateLoadError(
|
||||
f'Chat template "{prompt_template_name}" not found in files.'
|
||||
f'Chat template "{template_name}" not found in files.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_model_json(
|
||||
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||
async def from_model_json(
|
||||
cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||
):
|
||||
"""Get a template from a JSON file. Requires a key and template name"""
|
||||
if not json_path.exists():
|
||||
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
|
||||
|
||||
with open(json_path, "r", encoding="utf8") as config_file:
|
||||
model_config = json.load(config_file)
|
||||
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
|
||||
contents = await config_file.read()
|
||||
model_config = json.loads(contents)
|
||||
chat_template = model_config.get(key)
|
||||
|
||||
if not chat_template:
|
||||
@@ -162,7 +177,7 @@ class PromptTemplate:
|
||||
)
|
||||
else:
|
||||
# Can safely assume the chat template is the old style
|
||||
return PromptTemplate(
|
||||
return cls(
|
||||
name="from_tokenizer_config",
|
||||
raw_template=chat_template,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from typing import List, Optional, Union
|
||||
@@ -15,15 +16,16 @@ class GenerationConfig(BaseModel):
|
||||
bad_words_ids: Optional[List[List[int]]] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
with open(
|
||||
async with aiofiles.open(
|
||||
generation_config_path, "r", encoding="utf8"
|
||||
) as generation_config_json:
|
||||
generation_config_dict = json.load(generation_config_json)
|
||||
return self.model_validate(generation_config_dict)
|
||||
contents = await generation_config_json.read()
|
||||
generation_config_dict = json.loads(contents)
|
||||
return cls.model_validate(generation_config_dict)
|
||||
|
||||
def eos_tokens(self):
|
||||
"""Wrapper method to fetch EOS tokens."""
|
||||
@@ -43,13 +45,16 @@ class HuggingFaceConfig(BaseModel):
|
||||
badwordsids: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
hf_config_path = model_directory / "config.json"
|
||||
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
|
||||
hf_config_dict = json.load(hf_config_json)
|
||||
return self.model_validate(hf_config_dict)
|
||||
async with aiofiles.open(
|
||||
hf_config_path, "r", encoding="utf8"
|
||||
) as hf_config_json:
|
||||
contents = await hf_config_json.read()
|
||||
hf_config_dict = json.loads(contents)
|
||||
return cls.model_validate(hf_config_dict)
|
||||
|
||||
def get_badwordsids(self):
|
||||
"""Wrapper method to fetch badwordsids."""
|
||||
|
||||
Reference in New Issue
Block a user