Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-25 16:59:09 +00:00)

This commit touches three files: auth.py tightens an exception handler, main.py regroups its imports, and model.py splits a comma-separated import and several one-line compound statements.
auth.py (2 changed lines)
@@ -37,7 +37,7 @@ def load_auth_keys():
             api_key = auth_keys_dict["api_key"],
             admin_key = auth_keys_dict["admin_key"]
         )
-    except:
+    except Exception as _:
         new_auth_keys = AuthKeys(
             api_key = secrets.token_hex(16),
             admin_key = secrets.token_hex(16)
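The only change here swaps a bare `except:` for `except Exception as _:`. A bare `except` also catches `KeyboardInterrupt` and `SystemExit`, so a malformed key file could make the server hard to stop cleanly; catching `Exception` keeps the fallback behaviour while letting those control-flow exceptions propagate. A minimal sketch of the pattern, with a simplified stand-in for tabbyAPI's real AuthKeys model:

import secrets
from dataclasses import dataclass

@dataclass
class AuthKeys:            # stand-in for tabbyAPI's real AuthKeys model
    api_key: str
    admin_key: str

def load_auth_keys(auth_keys_dict=None):
    try:
        new_auth_keys = AuthKeys(
            api_key = auth_keys_dict["api_key"],
            admin_key = auth_keys_dict["admin_key"]
        )
    except Exception as _:
        # Exception (not a bare except) still catches the KeyError/TypeError
        # from a missing or malformed dict, but lets KeyboardInterrupt and
        # SystemExit propagate.
        new_auth_keys = AuthKeys(
            api_key = secrets.token_hex(16),
            admin_key = secrets.token_hex(16)
        )
    return new_auth_keys

print(load_auth_keys())    # no dict given, so it falls back to fresh keys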
main.py (11 changed lines)
@@ -1,15 +1,18 @@
 import uvicorn
 import yaml
 import pathlib
-import gen_logging
 from asyncio import CancelledError
-from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from model import ModelContainer
 from progress.bar import IncrementalBar
+from typing import Optional
+from uuid import uuid4
+
+import gen_logging
+from auth import check_admin_key, check_api_key, load_auth_keys
 from generators import generate_with_semaphore
+from model import ModelContainer
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
@@ -28,9 +31,7 @@ from OAI.utils import (
     create_chat_completion_response,
     create_chat_completion_stream_chunk
 )
-from typing import Optional
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap
-from uuid import uuid4
 
 app = FastAPI()
 
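Both main.py hunks are a single refactor: nothing is added or removed, the imports are just regrouped so that standard-library and third-party modules sit in the first block and the project's own modules (gen_logging, auth, generators, model, the OAI.* types, utils) follow after a blank line, roughly the layout PEP 8 recommends and tools like isort produce. The convention in miniature (the second block is described in comments because its modules only exist inside the repo):

# Block 1: standard-library and third-party imports.
import pathlib
from typing import Optional

# Block 2, after one blank line: first-party project modules.
# In tabbyAPI's main.py this block holds gen_logging, auth,
# generators, model, the OAI.* types, and utils.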
model.py (21 changed lines)
@@ -1,4 +1,6 @@
-import gc, time, pathlib
+import gc
+import pathlib
+import time
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -12,9 +14,10 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
+
+from gen_logging import log_generation_params, log_prompt, log_response
 from typing import List, Optional, Union
 from utils import coalesce, unwrap
-from gen_logging import log_generation_params, log_prompt, log_response
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
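The model.py import hunks apply the same pass: `import gc, time, pathlib` becomes one import per line, per PEP 8's "imports should usually be on separate lines", and the first-party gen_logging import moves up into its own block right after the exllamav2 imports instead of trailing the typing and utils lines.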
@@ -147,7 +150,8 @@ class ModelContainer:
             progress_callback (function, optional): A function to call for each module loaded. Prototype:
             def progress(loaded_modules: int, total_modules: int)
         """
-        for _ in self.load_gen(progress_callback): pass
+        for _ in self.load_gen(progress_callback):
+            pass
 
     def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
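`load` here is a thin wrapper that simply drains `load_gen`, a generator that yields progress updates so callers (main.py's progress bar, for instance) can iterate it directly; `load` itself only wants the side effects. A hedged sketch of that pattern, with a toy loader standing in for the real ExLlamaV2 module loading:

from typing import Callable, Generator, Optional

class Loader:
    """Toy stand-in for ModelContainer's load/load_gen pair."""

    def load_gen(self, progress_callback: Optional[Callable] = None) -> Generator:
        total_modules = 4
        for loaded in range(1, total_modules + 1):
            if progress_callback:
                progress_callback(loaded, total_modules)
            yield loaded, total_modules   # callers can render a progress bar

    def load(self, progress_callback: Optional[Callable] = None):
        # Exhaust the generator purely for its side effects; this loop
        # body is what the diff above expands onto its own line.
        for _ in self.load_gen(progress_callback):
            pass

Loader().load(lambda n, t: print(f"loaded {n}/{t} modules"))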
@@ -243,10 +247,14 @@ class ModelContainer:
 
         # Unload the entire model if not just unloading loras
         if not loras_only:
-            if self.model: self.model.unload()
+            if self.model:
+                self.model.unload()
             self.model = None
-            if self.draft_model: self.draft_model.unload()
+
+            if self.draft_model:
+                self.draft_model.unload()
             self.draft_model = None
+
         self.config = None
         self.cache = None
         self.tokenizer = None
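Past the reformat, the surrounding logic in this hunk is a standard VRAM-release pattern: call the backend's unload(), then null out every reference (model, draft_model, then config, cache, tokenizer) so the objects become garbage-collectible; the gc import at the top of model.py supports exactly this kind of cleanup. A sketch of the idea, assuming a toy backend in place of ExLlamaV2:

import gc

class ToyBackend:
    def unload(self):
        print("backend resources freed")

class Container:
    def __init__(self):
        self.model = ToyBackend()
        self.cache = object()

    def unload(self, loras_only: bool = False):
        # Release backend resources first, then drop the references so the
        # garbage collector can reclaim them; in tabbyAPI this is what lets
        # GPU memory actually return to the pool.
        if not loras_only:
            if self.model:
                self.model.unload()
            self.model = None

        self.cache = None
        gc.collect()  # assumption: tabbyAPI forces a collection similarly

Container().unload()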
@@ -440,7 +448,8 @@ class ModelContainer:
                 chunk_buffer = ""
                 last_chunk_time = now
 
-            if eos or generated_tokens == max_tokens: break
+            if eos or generated_tokens == max_tokens:
+                break
 
         # Print response
         log_response(full_response)
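Every model.py hunk below the imports is the same mechanical fix: PEP 8 discourages compound statements such as `if cond: do_thing()` on one line, so each one-liner gains an indented body. The stream loop's exit check above is the clearest case: stop when the model emits an end-of-sequence token or the max_tokens budget runs out. A toy version of that loop, with a fake token source in place of ExLlamaV2's streaming generator:

# Toy streaming loop showing the reformatted exit check; the token
# source here is fake, not ExLlamaV2's generator.
def stream_tokens(max_tokens: int):
    generated_tokens = 0
    eos = False
    while True:
        generated_tokens += 1
        eos = generated_tokens >= 3   # pretend the model stops early
        yield f"tok{generated_tokens}"
        if eos or generated_tokens == max_tokens:
            break

print(list(stream_tokens(max_tokens=5)))   # -> ['tok1', 'tok2', 'tok3']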