mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
config is now backed by pydantic (WIP)
- add models for config options
- add function to regenerate config.yml
- replace references to config with pydantic compatible references
- remove unnecessary unwrap() statements

TODO:
- auto generate env vars
- auto generate argparse
- test loading a model
This commit is contained in:
34
main.py
34
main.py
@@ -27,8 +27,8 @@ if not do_export_openapi:
|
||||
async def entrypoint_async():
|
||||
"""Async entry function for program startup"""
|
||||
|
||||
host = unwrap(config.network.get("host"), "127.0.0.1")
|
||||
port = unwrap(config.network.get("port"), 5000)
|
||||
host = config.network.host
|
||||
port = config.network.port
|
||||
|
||||
# Check if the port is available and attempt to bind a fallback
|
||||
if is_port_in_use(port):
|
||||
@@ -50,16 +50,12 @@ async def entrypoint_async():
|
||||
port = fallback_port
|
||||
|
||||
# Initialize auth keys
|
||||
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
||||
|
||||
# Override the generation log options if given
|
||||
if config.logging:
|
||||
gen_logging.update_from_dict(config.logging)
|
||||
load_auth_keys(config.network.disable_auth)
|
||||
|
||||
gen_logging.broadcast_status()
|
||||
|
||||
# Set sampler parameter overrides if provided
|
||||
sampling_override_preset = config.sampling.get("override_preset")
|
||||
sampling_override_preset = config.sampling.override_preset
|
||||
if sampling_override_preset:
|
||||
try:
|
||||
sampling.overrides_from_file(sampling_override_preset)
|
||||
@@ -68,25 +64,23 @@ async def entrypoint_async():
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
model_name = config.model.get("model_name")
|
||||
model_name = config.model.model_name
|
||||
if model_name:
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / model_name
|
||||
|
||||
await model.load_model(model_path.resolve(), **config.model)
|
||||
|
||||
# Load loras after loading the model
|
||||
if config.lora.get("loras"):
|
||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
if config.lora.loras:
|
||||
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
||||
|
||||
# If an initial embedding model name is specified, create a separate container
|
||||
# and load the model
|
||||
embedding_model_name = config.embeddings.get("embedding_model_name")
|
||||
embedding_model_name = config.embeddings.embedding_model_name
|
||||
if embedding_model_name:
|
||||
embedding_model_path = pathlib.Path(
|
||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||
embedding_model_path = embedding_model_path / embedding_model_name
|
||||
|
||||
try:
|
||||
@@ -124,7 +118,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
# Skip if launching unsafely
|
||||
print(f"MAIN.PY {config=}")
|
||||
if unwrap(config.developer.get("unsafe_launch"), False):
|
||||
if config.developer.unsafe_launch:
|
||||
logger.warning(
|
||||
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
||||
"If you aren't a developer, please keep this off!"
|
||||
@@ -133,12 +127,12 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
check_exllama_version()
|
||||
|
||||
# Enable CUDA malloc backend
|
||||
if unwrap(config.developer.get("cuda_malloc_backend"), False):
|
||||
if config.developer.cuda_malloc_backend:
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
||||
|
||||
# Use Uvloop/Winloop
|
||||
if unwrap(config.developer.get("uvloop"), False):
|
||||
if config.developer.uvloop:
|
||||
if platform.system() == "Windows":
|
||||
from winloop import install
|
||||
else:
|
||||
@@ -150,7 +144,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
||||
|
||||
# Set the process priority
|
||||
if unwrap(config.developer.get("realtime_process_priority"), False):
|
||||
if config.developer.realtime_process_priority:
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
|
||||
Reference in New Issue
Block a user