mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 02:31:48 +00:00
Tree: Refactor code organization
Move common functions into their own folder and refactor the backends to use their own folder as well. Also clean up imports and alphabetize the import statements themselves. Finally, move colab and docker into their own folders as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
from uuid import uuid4
|
|
||||||
from time import time
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from time import time
|
||||||
from typing import Union, List, Optional, Dict
|
from typing import Union, List, Optional, Dict
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from OAI.types.common import UsageStats, CommonCompletionRequest
|
from OAI.types.common import UsageStats, CommonCompletionRequest
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
""" Completion API protocols """
|
""" Completion API protocols """
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
|
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
""" Lora types """
|
""" Lora types """
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class LoraCard(BaseModel):
|
class LoraCard(BaseModel):
|
||||||
"""Represents a single Lora card."""
|
"""Represents a single Lora card."""
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
""" Contains model card types. """
|
""" Contains model card types. """
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from common.gen_logging import LogPreferences
|
||||||
|
|
||||||
from gen_logging import LogPreferences
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCardParameters(BaseModel):
|
class ModelCardParameters(BaseModel):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
""" Tokenization types """
|
""" Tokenization types """
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class CommonTokenRequest(BaseModel):
|
class CommonTokenRequest(BaseModel):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from common.utils import unwrap
|
||||||
from OAI.types.chat_completion import (
|
from OAI.types.chat_completion import (
|
||||||
ChatCompletionMessage,
|
ChatCompletionMessage,
|
||||||
ChatCompletionRespChoice,
|
ChatCompletionRespChoice,
|
||||||
@@ -14,8 +15,6 @@ from OAI.types.common import UsageStats
|
|||||||
from OAI.types.lora import LoraList, LoraCard
|
from OAI.types.lora import LoraList, LoraCard
|
||||||
from OAI.types.model import ModelList, ModelCard
|
from OAI.types.model import ModelList, ModelCard
|
||||||
|
|
||||||
from utils import unwrap
|
|
||||||
|
|
||||||
|
|
||||||
def create_completion_response(
|
def create_completion_response(
|
||||||
text: str,
|
text: str,
|
||||||
|
|||||||
@@ -13,17 +13,17 @@ from exllamav2 import (
|
|||||||
ExLlamaV2Lora,
|
ExLlamaV2Lora,
|
||||||
)
|
)
|
||||||
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
|
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
|
||||||
|
|
||||||
from gen_logging import log_generation_params, log_prompt, log_response
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from templating import (
|
|
||||||
|
from common.gen_logging import log_generation_params, log_prompt, log_response
|
||||||
|
from common.templating import (
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
find_template_from_model,
|
find_template_from_model,
|
||||||
get_template_from_model_json,
|
get_template_from_model_json,
|
||||||
get_template_from_file,
|
get_template_from_file,
|
||||||
)
|
)
|
||||||
from utils import coalesce, unwrap
|
from common.utils import coalesce, unwrap
|
||||||
from logger import init_logger
|
from common.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ logger = init_logger(__name__)
|
|||||||
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
|
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
|
||||||
|
|
||||||
|
|
||||||
class ModelContainer:
|
class ExllamaV2Container:
|
||||||
"""The model container class for ExLlamaV2 models."""
|
"""The model container class for ExLlamaV2 models."""
|
||||||
|
|
||||||
config: Optional[ExLlamaV2Config] = None
|
config: Optional[ExLlamaV2Config] = None
|
||||||
@@ -3,13 +3,12 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
|
|||||||
application, it should be fine.
|
application, it should be fine.
|
||||||
"""
|
"""
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Optional
|
import yaml
|
||||||
|
|
||||||
from fastapi import Header, HTTPException
|
from fastapi import Header, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import yaml
|
from typing import Optional
|
||||||
|
|
||||||
from logger import init_logger
|
from common.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
import yaml
|
import yaml
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from logger import init_logger
|
from common.logger import init_logger
|
||||||
from utils import unwrap
|
from common.utils import unwrap
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -4,7 +4,7 @@ Functions for logging generation events.
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from logger import init_logger
|
from common.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from logger import init_logger
|
from common.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -2,7 +2,8 @@ version: '3.8'
|
|||||||
services:
|
services:
|
||||||
tabbyapi:
|
tabbyapi:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: ..
|
||||||
|
dockerfile: ./docker/Dockerfile
|
||||||
ports:
|
ports:
|
||||||
- "5000:5000"
|
- "5000:5000"
|
||||||
environment:
|
environment:
|
||||||
27
main.py
27
main.py
@@ -11,10 +11,11 @@ from fastapi.responses import StreamingResponse
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from progress.bar import IncrementalBar
|
from progress.bar import IncrementalBar
|
||||||
|
|
||||||
import gen_logging
|
import common.gen_logging as gen_logging
|
||||||
from args import convert_args_to_dict, init_argparser
|
from backends.exllamav2.model import ExllamaV2Container
|
||||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
from common.args import convert_args_to_dict, init_argparser
|
||||||
from config import (
|
from common.auth import check_admin_key, check_api_key, load_auth_keys
|
||||||
|
from common.config import (
|
||||||
override_config_from_args,
|
override_config_from_args,
|
||||||
read_config_from_file,
|
read_config_from_file,
|
||||||
get_gen_logging_config,
|
get_gen_logging_config,
|
||||||
@@ -23,8 +24,10 @@ from config import (
|
|||||||
get_lora_config,
|
get_lora_config,
|
||||||
get_network_config,
|
get_network_config,
|
||||||
)
|
)
|
||||||
from generators import call_with_semaphore, generate_with_semaphore
|
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||||
from model import ModelContainer
|
from common.templating import get_all_templates, get_prompt_from_template
|
||||||
|
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||||
|
from common.logger import init_logger
|
||||||
from OAI.types.completion import CompletionRequest
|
from OAI.types.completion import CompletionRequest
|
||||||
from OAI.types.chat_completion import ChatCompletionRequest
|
from OAI.types.chat_completion import ChatCompletionRequest
|
||||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||||
@@ -48,9 +51,6 @@ from OAI.utils_oai import (
|
|||||||
create_chat_completion_response,
|
create_chat_completion_response,
|
||||||
create_chat_completion_stream_chunk,
|
create_chat_completion_stream_chunk,
|
||||||
)
|
)
|
||||||
from templating import get_all_templates, get_prompt_from_template
|
|
||||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
|
||||||
from logger import init_logger
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ app = FastAPI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Globally scoped variables. Undefined until initialized in main
|
# Globally scoped variables. Undefined until initialized in main
|
||||||
MODEL_CONTAINER: Optional[ModelContainer] = None
|
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
|
||||||
|
|
||||||
|
|
||||||
def _check_model_container():
|
def _check_model_container():
|
||||||
@@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||||
|
|
||||||
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
|
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
"""Generator for the loading process."""
|
"""Generator for the loading process."""
|
||||||
@@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None):
|
|||||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||||
model_path = model_path / model_name
|
model_path = model_path / model_name
|
||||||
|
|
||||||
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
|
MODEL_CONTAINER = ExllamaV2Container(
|
||||||
|
model_path.resolve(), False, **model_config
|
||||||
|
)
|
||||||
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||||
for module, modules in load_status:
|
for module, modules in load_status:
|
||||||
if module == 0:
|
if module == 0:
|
||||||
@@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None):
|
|||||||
host = unwrap(network_config.get("host"), "127.0.0.1")
|
host = unwrap(network_config.get("host"), "127.0.0.1")
|
||||||
port = unwrap(network_config.get("port"), 5000)
|
port = unwrap(network_config.get("port"), 5000)
|
||||||
|
|
||||||
|
# TODO: Move OAI API to a separate folder
|
||||||
logger.info(f"Developer documentation: http://{host}:{port}/docs")
|
logger.info(f"Developer documentation: http://{host}:{port}/docs")
|
||||||
logger.info(f"Completions: http://{host}:{port}/v1/completions")
|
logger.info(f"Completions: http://{host}:{port}/v1/completions")
|
||||||
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
|
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
|
||||||
|
|||||||
2
start.py
2
start.py
@@ -3,7 +3,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
from args import convert_args_to_dict, init_argparser
|
from common.args import convert_args_to_dict, init_argparser
|
||||||
|
|
||||||
|
|
||||||
def get_requirements_file():
|
def get_requirements_file():
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
""" Test the model container. """
|
""" Test the model container. """
|
||||||
from model import ModelContainer
|
from backends.exllamav2.model import ModelContainer
|
||||||
|
|
||||||
|
|
||||||
def progress(module, modules):
|
def progress(module, modules):
|
||||||
|
|||||||
Reference in New Issue
Block a user