Tree: Refactor code organization

Move common functions into their own folder and refactor the backends
to use their own folder as well.

Also clean up imports and alphabetize the import statements themselves.

Finally, move colab and docker into their own folders as well.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-01-18 00:42:52 -05:00
Committer: Brian Dashore
Parent: ee99349a78
Commit: 78f920eeda
22 changed files with 41 additions and 42 deletions
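As a quick orientation before the diff, a minimal sketch of the new import layout for callers of the relocated modules (only modules visible in the diff below; anything beyond that is an assumption):

# Shared helpers now live under the common package
import common.gen_logging as gen_logging
from common.args import convert_args_to_dict, init_argparser
from common.logger import init_logger

# Each backend gets its own package; ExllamaV2Container replaces the old top-level ModelContainer
from backends.exllamav2.model import ExllamaV2Container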

main.py (27 changes)

@@ -11,10 +11,11 @@ from fastapi.responses import StreamingResponse
 from functools import partial
 from progress.bar import IncrementalBar
-import gen_logging
-from args import convert_args_to_dict, init_argparser
-from auth import check_admin_key, check_api_key, load_auth_keys
-from config import (
+import common.gen_logging as gen_logging
+from backends.exllamav2.model import ExllamaV2Container
+from common.args import convert_args_to_dict, init_argparser
+from common.auth import check_admin_key, check_api_key, load_auth_keys
+from common.config import (
     override_config_from_args,
     read_config_from_file,
     get_gen_logging_config,
@@ -23,8 +24,10 @@ from config import (
     get_lora_config,
     get_network_config,
 )
-from generators import call_with_semaphore, generate_with_semaphore
-from model import ModelContainer
+from common.generators import call_with_semaphore, generate_with_semaphore
+from common.templating import get_all_templates, get_prompt_from_template
+from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
+from common.logger import init_logger
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
@@ -48,9 +51,6 @@ from OAI.utils_oai import (
     create_chat_completion_response,
     create_chat_completion_stream_chunk,
 )
-from templating import get_all_templates, get_prompt_from_template
-from utils import get_generator_error, get_sse_packet, load_progress, unwrap
-from logger import init_logger
 
 logger = init_logger(__name__)
@@ -64,7 +64,7 @@ app = FastAPI(
 )
 
 # Globally scoped variables. Undefined until initalized in main
-MODEL_CONTAINER: Optional[ModelContainer] = None
+MODEL_CONTAINER: Optional[ExllamaV2Container] = None
 
 
 def _check_model_container():
@@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
 
-    MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
+    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
 
     async def generator():
         """Generator for the loading process."""
@@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None):
         model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
         model_path = model_path / model_name
 
-        MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
+        MODEL_CONTAINER = ExllamaV2Container(
+            model_path.resolve(), False, **model_config
+        )
         load_status = MODEL_CONTAINER.load_gen(load_progress)
         for module, modules in load_status:
             if module == 0:
@@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)
 
+    # TODO: Move OAI API to a separate folder
     logger.info(f"Developer documentation: http://{host}:{port}/docs")
     logger.info(f"Completions: http://{host}:{port}/v1/completions")
     logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")