mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 09:41:54 +00:00
Merge pull request #185 from SecretiveShell/refactor-config-loading
Refactor config loading
This commit is contained in:
107
common/config.py
107
common/config.py
@@ -1,107 +0,0 @@
|
|||||||
import yaml
|
|
||||||
import pathlib
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from common.utils import unwrap
|
|
||||||
|
|
||||||
# Global config dictionary constant
|
|
||||||
GLOBAL_CONFIG: dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
def from_file(config_path: pathlib.Path):
    """Sets the global config from a given file path"""
    global GLOBAL_CONFIG

    resolved_path = str(config_path.resolve())
    try:
        # Best-effort load: any failure (missing file, bad YAML) is logged
        # and the server continues with an empty config
        with open(resolved_path, "r", encoding="utf8") as config_file:
            GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
    except Exception as exc:
        logger.error(
            "The YAML config couldn't load because of the following error: "
            f"\n\n{exc}"
            "\n\nTabbyAPI will start anyway and not parse this config file."
        )
        GLOBAL_CONFIG = {}
|
|
||||||
|
|
||||||
|
|
||||||
def from_args(args: dict):
    """Overrides the config based on a dict representation of args.

    If a config-file override is present in the args, the global config is
    reloaded from that file and all other arg overrides are ignored.
    Otherwise each recognized section from args is shallow-merged on top of
    the corresponding section of the global config.
    """

    config_override = unwrap(args.get("options", {}).get("config"))
    if config_override:
        logger.info("Attempting to override config.yml from args.")
        from_file(pathlib.Path(config_override))
        return

    # The original code repeated the same merge stanza per section;
    # consolidated into one loop with identical semantics. Note "draft" and
    # "lora" are nested under "model" in YAML and are not arg-overridable
    # here, matching the original behavior.
    for key in ["network", "model", "logging", "developer", "embeddings"]:
        override = args.get(key)
        if override:
            if key == "logging":
                # CLI logging flags carry a "log_" prefix; strip it so keys
                # match the YAML config schema
                override = {
                    k.replace("log_", ""): v for k, v in override.items()
                }
            current_section = unwrap(GLOBAL_CONFIG.get(key), {})
            GLOBAL_CONFIG[key] = {**current_section, **override}
|
|
||||||
|
|
||||||
|
|
||||||
def _section(name: str) -> dict:
    """Fetch a top-level section of the global config, defaulting to {}."""
    return unwrap(GLOBAL_CONFIG.get(name), {})


def sampling_config():
    """Returns the sampling parameter config from the global config"""
    return _section("sampling")


def model_config():
    """Returns the model config from the global config"""
    return _section("model")


def draft_model_config():
    """Returns the draft model config from the global config"""
    # Draft config is nested under the "model" section
    return unwrap(_section("model").get("draft"), {})


def lora_config():
    """Returns the lora config from the global config"""
    # Lora config is nested under the "model" section
    return unwrap(_section("model").get("lora"), {})


def network_config():
    """Returns the network config from the global config"""
    return _section("network")


def logging_config():
    """Returns the logging config from the global config"""
    return _section("logging")


def developer_config():
    """Returns the developer specific config from the global config"""
    return _section("developer")


def embeddings_config():
    """Returns the embeddings config from the global config"""
    return _section("embeddings")
|
|
||||||
@@ -10,8 +10,8 @@ from loguru import logger
|
|||||||
from rich.progress import Progress
|
from rich.progress import Progress
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from common.config import lora_config, model_config
|
|
||||||
from common.logger import get_progress_bar
|
from common.logger import get_progress_bar
|
||||||
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
|
|
||||||
|
|
||||||
@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
|
|||||||
"""Gets the download folder for the repo."""
|
"""Gets the download folder for the repo."""
|
||||||
|
|
||||||
if repo_type == "lora":
|
if repo_type == "lora":
|
||||||
download_path = pathlib.Path(lora_config().get("lora_dir") or "loras")
|
download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
|
||||||
else:
|
else:
|
||||||
download_path = pathlib.Path(model_config().get("model_dir") or "models")
|
download_path = pathlib.Path(config.model.get("model_dir") or "models")
|
||||||
|
|
||||||
download_path = download_path / (folder_name or repo_id.split("/")[-1])
|
download_path = download_path / (folder_name or repo_id.split("/")[-1])
|
||||||
return download_path
|
return download_path
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from fastapi import HTTPException
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from common import config
|
|
||||||
from common.logger import get_loading_progress_bar
|
from common.logger import get_loading_progress_bar
|
||||||
from common.networking import handle_request_error
|
from common.networking import handle_request_error
|
||||||
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.utils import do_export_openapi
|
from endpoints.utils import do_export_openapi
|
||||||
|
|
||||||
@@ -153,8 +153,7 @@ async def unload_embedding_model():
|
|||||||
def get_config_default(key: str, model_type: str = "model"):
|
def get_config_default(key: str, model_type: str = "model"):
|
||||||
"""Fetches a default value from model config if allowed by the user."""
|
"""Fetches a default value from model config if allowed by the user."""
|
||||||
|
|
||||||
model_config = config.model_config()
|
default_keys = unwrap(config.model.get("use_as_default"), [])
|
||||||
default_keys = unwrap(model_config.get("use_as_default"), [])
|
|
||||||
|
|
||||||
# Add extra keys to defaults
|
# Add extra keys to defaults
|
||||||
default_keys.append("embeddings_device")
|
default_keys.append("embeddings_device")
|
||||||
@@ -162,13 +161,11 @@ def get_config_default(key: str, model_type: str = "model"):
|
|||||||
if key in default_keys:
|
if key in default_keys:
|
||||||
# Is this a draft model load parameter?
|
# Is this a draft model load parameter?
|
||||||
if model_type == "draft":
|
if model_type == "draft":
|
||||||
draft_config = config.draft_model_config()
|
return config.draft_model.get(key)
|
||||||
return draft_config.get(key)
|
|
||||||
elif model_type == "embedding":
|
elif model_type == "embedding":
|
||||||
embeddings_config = config.embeddings_config()
|
return config.embeddings.get(key)
|
||||||
return embeddings_config.get(key)
|
|
||||||
else:
|
else:
|
||||||
return model_config.get(key)
|
return config.model.get(key)
|
||||||
|
|
||||||
|
|
||||||
async def check_model_container():
|
async def check_model_container():
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from common import config
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
|
|||||||
"""Log a request error to the console."""
|
"""Log a request error to the console."""
|
||||||
|
|
||||||
trace = traceback.format_exc()
|
trace = traceback.format_exc()
|
||||||
send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
|
send_trace = unwrap(config.network.get("send_tracebacks"), False)
|
||||||
|
|
||||||
error_message = TabbyRequestErrorMessage(
|
error_message = TabbyRequestErrorMessage(
|
||||||
message=message, trace=trace if send_trace else None
|
message=message, trace=trace if send_trace else None
|
||||||
@@ -134,7 +134,7 @@ def get_global_depends():
|
|||||||
|
|
||||||
depends = [Depends(add_request_id)]
|
depends = [Depends(add_request_id)]
|
||||||
|
|
||||||
if config.logging_config().get("requests"):
|
if config.logging.get("requests"):
|
||||||
depends.append(Depends(log_request))
|
depends.append(Depends(log_request))
|
||||||
|
|
||||||
return depends
|
return depends
|
||||||
|
|||||||
88
common/tabby_config.py
Normal file
88
common/tabby_config.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import yaml
|
||||||
|
import pathlib
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from common.utils import unwrap, merge_dicts
|
||||||
|
|
||||||
|
|
||||||
|
class TabbyConfig:
    """Container for the application's configuration.

    Each attribute holds one top-level section of the merged config as a
    plain dict; sections absent from every source default to empty dicts.
    Call ``load()`` once at startup to populate the sections.
    """

    network: dict = {}
    logging: dict = {}
    model: dict = {}
    draft_model: dict = {}
    lora: dict = {}
    sampling: dict = {}
    developer: dict = {}
    embeddings: dict = {}

    def load(self, arguments: Optional[dict] = None):
        """load the global application config"""

        # config is applied in order of items in the list
        configs = [
            self._from_file(pathlib.Path("config.yml")),
            self._from_args(unwrap(arguments, {})),
        ]

        merged_config = merge_dicts(*configs)

        self.network = unwrap(merged_config.get("network"), {})
        self.logging = unwrap(merged_config.get("logging"), {})
        self.model = unwrap(merged_config.get("model"), {})
        self.draft_model = unwrap(merged_config.get("draft"), {})
        # FIX: previously read the "draft" key here as well, so the lora
        # section was silently replaced by the draft-model config
        self.lora = unwrap(merged_config.get("lora"), {})
        self.sampling = unwrap(merged_config.get("sampling"), {})
        self.developer = unwrap(merged_config.get("developer"), {})
        self.embeddings = unwrap(merged_config.get("embeddings"), {})

    def _from_file(self, config_path: pathlib.Path):
        """loads config from a given file path"""

        # try loading from file; a missing file is informational only,
        # any other failure is logged and treated as an empty config
        try:
            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
                return unwrap(yaml.safe_load(config_file), {})
        except FileNotFoundError:
            logger.info(f"The '{config_path.name}' file cannot be found")
        except Exception as exc:
            logger.error(
                f"The YAML config from '{config_path.name}' couldn't load because of "
                f"the following error:\n\n{exc}"
            )

        # if no config file was loaded
        return {}

    def _from_args(self, args: dict):
        """loads config from the provided arguments"""
        config = {}

        config_override = unwrap(args.get("options", {}).get("config"))
        if config_override:
            logger.info("Config file override detected in args.")
            # FIX: was ``self.from_file`` which does not exist and raised
            # AttributeError whenever a --config override was supplied
            config = self._from_file(pathlib.Path(config_override))
            return config  # Return early if loading from file

        for key in ["network", "model", "logging", "developer", "embeddings"]:
            override = args.get(key)
            if override:
                if key == "logging":
                    # Strip the "log_" prefix from logging keys if present
                    override = {k.replace("log_", ""): v for k, v in override.items()}
                config[key] = override

        return config

    def _from_environment(self):
        """loads configuration from environment variables"""

        # TODO: load config from environment variables
        # this means that we can have host default to 0.0.0.0 in docker for example
        # this would also mean that docker containers no longer require a non
        # default config file to be used
        pass
||||||
|
|
||||||
|
|
||||||
|
# Create an empty instance of the config class
|
||||||
|
config: TabbyConfig = TabbyConfig()
|
||||||
@@ -20,6 +20,25 @@ def prune_dict(input_dict):
|
|||||||
return {k: v for k, v in input_dict.items() if v is not None}
|
return {k: v for k, v in input_dict.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dict(dict1, dict2):
    """Merge 2 dictionaries.

    Mutates and returns ``dict1``: values from ``dict2`` win, except that
    nested dicts present on both sides are merged recursively. Non-dict
    values from ``dict2`` are assigned by reference, not copied.
    """
    for key, incoming in dict2.items():
        existing = dict1.get(key)
        if isinstance(incoming, dict) and isinstance(existing, dict):
            merge_dict(existing, incoming)
        else:
            dict1[key] = incoming
    return dict1


def merge_dicts(*dicts):
    """Merge an arbitrary amount of dictionaries.

    Later dictionaries take precedence; the inputs themselves are not
    mutated (merging starts from a fresh dict).
    """
    merged = {}
    for candidate in dicts:
        merged = merge_dict(merged, candidate)
    return merged
|
||||||
|
|
||||||
|
|
||||||
def flat_map(input_list):
|
def flat_map(input_list):
|
||||||
"""Flattens a list of lists into a single list."""
|
"""Flattens a list of lists into a single list."""
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
from sys import maxsize
|
from sys import maxsize
|
||||||
|
|
||||||
from common import config, model
|
from common import model
|
||||||
from common.auth import check_api_key
|
from common.auth import check_api_key
|
||||||
from common.model import check_embeddings_container, check_model_container
|
from common.model import check_embeddings_container, check_model_container
|
||||||
from common.networking import handle_request_error, run_with_request_disconnect
|
from common.networking import handle_request_error, run_with_request_disconnect
|
||||||
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||||
from endpoints.OAI.types.chat_completion import (
|
from endpoints.OAI.types.chat_completion import (
|
||||||
@@ -64,7 +65,7 @@ async def completion_request(
|
|||||||
data.prompt = "\n".join(data.prompt)
|
data.prompt = "\n".join(data.prompt)
|
||||||
|
|
||||||
disable_request_streaming = unwrap(
|
disable_request_streaming = unwrap(
|
||||||
config.developer_config().get("disable_request_streaming"), False
|
config.developer.get("disable_request_streaming"), False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set an empty JSON schema if the request wants a JSON response
|
# Set an empty JSON schema if the request wants a JSON response
|
||||||
@@ -128,7 +129,7 @@ async def chat_completion_request(
|
|||||||
data.json_schema = {"type": "object"}
|
data.json_schema = {"type": "object"}
|
||||||
|
|
||||||
disable_request_streaming = unwrap(
|
disable_request_streaming = unwrap(
|
||||||
config.developer_config().get("disable_request_streaming"), False
|
config.developer.get("disable_request_streaming"), False
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.stream and not disable_request_streaming:
|
if data.stream and not disable_request_streaming:
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ from sys import maxsize
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
from common import config, model, sampling
|
from common import model, sampling
|
||||||
from common.auth import check_admin_key, check_api_key, get_key_permission
|
from common.auth import check_admin_key, check_api_key, get_key_permission
|
||||||
from common.downloader import hf_repo_download
|
from common.downloader import hf_repo_download
|
||||||
from common.model import check_embeddings_container, check_model_container
|
from common.model import check_embeddings_container, check_model_container
|
||||||
from common.networking import handle_request_error, run_with_request_disconnect
|
from common.networking import handle_request_error, run_with_request_disconnect
|
||||||
|
from common.tabby_config import config
|
||||||
from common.templating import PromptTemplate, get_all_templates
|
from common.templating import PromptTemplate, get_all_templates
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.core.types.auth import AuthPermissionResponse
|
from endpoints.core.types.auth import AuthPermissionResponse
|
||||||
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
|
|||||||
Requires an admin key to see all models.
|
Requires an admin key to see all models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = config.model_config()
|
model_dir = unwrap(config.model.get("model_dir"), "models")
|
||||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
|
||||||
model_path = pathlib.Path(model_dir)
|
model_path = pathlib.Path(model_dir)
|
||||||
|
|
||||||
draft_model_dir = config.draft_model_config().get("draft_model_dir")
|
draft_model_dir = config.draft_model.get("draft_model_dir")
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||||
else:
|
else:
|
||||||
models = await get_current_model_list()
|
models = await get_current_model_list()
|
||||||
|
|
||||||
if unwrap(model_config.get("use_dummy_models"), False):
|
if unwrap(config.model.get("use_dummy_models"), False):
|
||||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||||
|
|
||||||
return models
|
return models
|
||||||
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
draft_model_dir = unwrap(
|
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||||
config.draft_model_config().get("draft_model_dir"), "models"
|
|
||||||
)
|
|
||||||
draft_model_path = pathlib.Path(draft_model_dir)
|
draft_model_path = pathlib.Path(draft_model_dir)
|
||||||
|
|
||||||
models = get_model_list(draft_model_path.resolve())
|
models = get_model_list(draft_model_path.resolve())
|
||||||
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||||
model_path = model_path / data.name
|
model_path = model_path / data.name
|
||||||
|
|
||||||
draft_model_path = None
|
draft_model_path = None
|
||||||
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
draft_model_path = unwrap(
|
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||||
config.draft_model_config().get("draft_model_dir"), "models"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||||
loras = get_lora_list(lora_path.resolve())
|
loras = get_lora_list(lora_path.resolve())
|
||||||
else:
|
else:
|
||||||
loras = get_active_loras()
|
loras = get_active_loras()
|
||||||
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||||
if not lora_dir.exists():
|
if not lora_dir.exists():
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||||
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
|
|||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
embedding_model_dir = unwrap(
|
embedding_model_dir = unwrap(
|
||||||
config.embeddings_config().get("embedding_model_dir"), "models"
|
config.embeddings.get("embedding_model_dir"), "models"
|
||||||
)
|
)
|
||||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||||
|
|
||||||
@@ -307,7 +303,7 @@ async def load_embedding_model(
|
|||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
embedding_model_dir = pathlib.Path(
|
embedding_model_dir = pathlib.Path(
|
||||||
unwrap(config.model_config().get("embedding_model_dir"), "models")
|
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||||
)
|
)
|
||||||
embedding_model_path = embedding_model_dir / data.name
|
embedding_model_path = embedding_model_dir / data.name
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from common import config
|
|
||||||
from common.logger import UVICORN_LOG_CONFIG
|
from common.logger import UVICORN_LOG_CONFIG
|
||||||
from common.networking import get_global_depends
|
from common.networking import get_global_depends
|
||||||
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.Kobold import router as KoboldRouter
|
from endpoints.Kobold import router as KoboldRouter
|
||||||
from endpoints.OAI import router as OAIRouter
|
from endpoints.OAI import router as OAIRouter
|
||||||
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
api_servers = unwrap(config.network_config().get("api_servers"), [])
|
api_servers = unwrap(config.network.get("api_servers"), [])
|
||||||
|
|
||||||
# Map for API id to server router
|
# Map for API id to server router
|
||||||
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
||||||
|
|||||||
58
main.py
58
main.py
@@ -9,12 +9,13 @@ import signal
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from common import config, gen_logging, sampling, model
|
from common import gen_logging, sampling, model
|
||||||
from common.args import convert_args_to_dict, init_argparser
|
from common.args import convert_args_to_dict, init_argparser
|
||||||
from common.auth import load_auth_keys
|
from common.auth import load_auth_keys
|
||||||
from common.logger import setup_logger
|
from common.logger import setup_logger
|
||||||
from common.networking import is_port_in_use
|
from common.networking import is_port_in_use
|
||||||
from common.signals import signal_handler
|
from common.signals import signal_handler
|
||||||
|
from common.tabby_config import config
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.server import export_openapi, start_api
|
from endpoints.server import export_openapi, start_api
|
||||||
from endpoints.utils import do_export_openapi
|
from endpoints.utils import do_export_openapi
|
||||||
@@ -26,10 +27,8 @@ if not do_export_openapi:
|
|||||||
async def entrypoint_async():
|
async def entrypoint_async():
|
||||||
"""Async entry function for program startup"""
|
"""Async entry function for program startup"""
|
||||||
|
|
||||||
network_config = config.network_config()
|
host = unwrap(config.network.get("host"), "127.0.0.1")
|
||||||
|
port = unwrap(config.network.get("port"), 5000)
|
||||||
host = unwrap(network_config.get("host"), "127.0.0.1")
|
|
||||||
port = unwrap(network_config.get("port"), 5000)
|
|
||||||
|
|
||||||
# Check if the port is available and attempt to bind a fallback
|
# Check if the port is available and attempt to bind a fallback
|
||||||
if is_port_in_use(port):
|
if is_port_in_use(port):
|
||||||
@@ -51,18 +50,16 @@ async def entrypoint_async():
|
|||||||
port = fallback_port
|
port = fallback_port
|
||||||
|
|
||||||
# Initialize auth keys
|
# Initialize auth keys
|
||||||
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
|
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
||||||
|
|
||||||
# Override the generation log options if given
|
# Override the generation log options if given
|
||||||
log_config = config.logging_config()
|
if config.logging:
|
||||||
if log_config:
|
gen_logging.update_from_dict(config.logging)
|
||||||
gen_logging.update_from_dict(log_config)
|
|
||||||
|
|
||||||
gen_logging.broadcast_status()
|
gen_logging.broadcast_status()
|
||||||
|
|
||||||
# Set sampler parameter overrides if provided
|
# Set sampler parameter overrides if provided
|
||||||
sampling_config = config.sampling_config()
|
sampling_override_preset = config.sampling.get("override_preset")
|
||||||
sampling_override_preset = sampling_config.get("override_preset")
|
|
||||||
if sampling_override_preset:
|
if sampling_override_preset:
|
||||||
try:
|
try:
|
||||||
sampling.overrides_from_file(sampling_override_preset)
|
sampling.overrides_from_file(sampling_override_preset)
|
||||||
@@ -71,32 +68,29 @@ async def entrypoint_async():
|
|||||||
|
|
||||||
# If an initial model name is specified, create a container
|
# If an initial model name is specified, create a container
|
||||||
# and load the model
|
# and load the model
|
||||||
model_config = config.model_config()
|
model_name = config.model.get("model_name")
|
||||||
model_name = model_config.get("model_name")
|
|
||||||
if model_name:
|
if model_name:
|
||||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||||
model_path = model_path / model_name
|
model_path = model_path / model_name
|
||||||
|
|
||||||
await model.load_model(model_path.resolve(), **model_config)
|
await model.load_model(model_path.resolve(), **config.model)
|
||||||
|
|
||||||
# Load loras after loading the model
|
# Load loras after loading the model
|
||||||
lora_config = config.lora_config()
|
if config.lora.get("loras"):
|
||||||
if lora_config.get("loras"):
|
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
||||||
await model.container.load_loras(lora_dir.resolve(), **lora_config)
|
|
||||||
|
|
||||||
# If an initial embedding model name is specified, create a separate container
|
# If an initial embedding model name is specified, create a separate container
|
||||||
# and load the model
|
# and load the model
|
||||||
embedding_config = config.embeddings_config()
|
embedding_model_name = config.embeddings.get("embedding_model_name")
|
||||||
embedding_model_name = embedding_config.get("embedding_model_name")
|
|
||||||
if embedding_model_name:
|
if embedding_model_name:
|
||||||
embedding_model_path = pathlib.Path(
|
embedding_model_path = pathlib.Path(
|
||||||
unwrap(embedding_config.get("embedding_model_dir"), "models")
|
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||||
)
|
)
|
||||||
embedding_model_path = embedding_model_path / embedding_model_name
|
embedding_model_path = embedding_model_path / embedding_model_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await model.load_embedding_model(embedding_model_path, **embedding_config)
|
await model.load_embedding_model(embedding_model_path, **config.embeddings)
|
||||||
except ImportError as ex:
|
except ImportError as ex:
|
||||||
logger.error(ex.msg)
|
logger.error(ex.msg)
|
||||||
|
|
||||||
@@ -110,15 +104,13 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
# Load from YAML config
|
|
||||||
config.from_file(pathlib.Path("config.yml"))
|
|
||||||
|
|
||||||
# Parse and override config from args
|
# Parse and override config from args
|
||||||
if arguments is None:
|
if arguments is None:
|
||||||
parser = init_argparser()
|
parser = init_argparser()
|
||||||
arguments = convert_args_to_dict(parser.parse_args(), parser)
|
arguments = convert_args_to_dict(parser.parse_args(), parser)
|
||||||
|
|
||||||
config.from_args(arguments)
|
# load config
|
||||||
|
config.load(arguments)
|
||||||
|
|
||||||
if do_export_openapi:
|
if do_export_openapi:
|
||||||
openapi_json = export_openapi()
|
openapi_json = export_openapi()
|
||||||
@@ -129,12 +121,10 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
developer_config = config.developer_config()
|
|
||||||
|
|
||||||
# Check exllamav2 version and give a descriptive error if it's too old
|
# Check exllamav2 version and give a descriptive error if it's too old
|
||||||
# Skip if launching unsafely
|
# Skip if launching unsafely
|
||||||
|
print(f"MAIN.PY {config=}")
|
||||||
if unwrap(developer_config.get("unsafe_launch"), False):
|
if unwrap(config.developer.get("unsafe_launch"), False):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
||||||
"If you aren't a developer, please keep this off!"
|
"If you aren't a developer, please keep this off!"
|
||||||
@@ -143,12 +133,12 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
check_exllama_version()
|
check_exllama_version()
|
||||||
|
|
||||||
# Enable CUDA malloc backend
|
# Enable CUDA malloc backend
|
||||||
if unwrap(developer_config.get("cuda_malloc_backend"), False):
|
if unwrap(config.developer.get("cuda_malloc_backend"), False):
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||||
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
||||||
|
|
||||||
# Use Uvloop/Winloop
|
# Use Uvloop/Winloop
|
||||||
if unwrap(developer_config.get("uvloop"), False):
|
if unwrap(config.developer.get("uvloop"), False):
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
from winloop import install
|
from winloop import install
|
||||||
else:
|
else:
|
||||||
@@ -160,7 +150,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
||||||
|
|
||||||
# Set the process priority
|
# Set the process priority
|
||||||
if unwrap(developer_config.get("realtime_process_priority"), False):
|
if unwrap(config.developer.get("realtime_process_priority"), False):
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
current_process = psutil.Process(os.getpid())
|
current_process = psutil.Process(os.getpid())
|
||||||
|
|||||||
Reference in New Issue
Block a user