mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 17:51:36 +00:00
Implement lora support (#24)
* Model: Implement basic lora support
* Add ability to load loras from config on launch
* Supports loading multiple loras and lora scaling
* Add function to unload loras
* Colab: Update for basic lora support
* Model: Test vram alloc after lora load, add docs
* Git: Add loras folder to .gitignore
* API: Add basic lora-related endpoints
* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints
* Revert bad CRLF line ending changes
* API: Add basic lora-related endpoints (fixed)
* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints
* Model: Unload loras first when unloading model
* API + Models: Cleanup lora endpoints and functions

Condenses down endpoint and model load code. Also makes the routes behave the same way as model routes to help not confuse the end user.

Signed-off-by: kingbri <bdashore3@proton.me>

* Loras: Optimize load endpoint

Return successes and failures along with consolidating the request to the rewritten load_loras function.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: kingbri <bdashore3@proton.me>
Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
This commit is contained in:
67
main.py
67
main.py
@@ -11,6 +11,7 @@ from progress.bar import IncrementalBar
|
||||
from generators import generate_with_semaphore
|
||||
from OAI.types.completion import CompletionRequest
|
||||
from OAI.types.chat_completion import ChatCompletionRequest
|
||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
@@ -21,6 +22,7 @@ from OAI.types.token import (
|
||||
from OAI.utils import (
|
||||
create_completion_response,
|
||||
get_model_list,
|
||||
get_lora_list,
|
||||
get_chat_completion_prompt,
|
||||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk
|
||||
@@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
# TODO: Move this to model_container
|
||||
model_config = config.get("model") or {}
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = model_path / data.name
|
||||
@@ -160,7 +161,63 @@ async def unload_model():
|
||||
global model_container
|
||||
|
||||
model_container.unload()
|
||||
model_container = None
|
||||
model_container = None
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_all_loras():
    """Return the lora listing for the configured lora directory.

    The directory is read from ``model.lora.lora_dir`` in the config and
    falls back to ``loras`` when that key is missing or empty. The
    resolved path is handed to ``get_lora_list`` and its result is
    returned unchanged.
    """
    # Each level of the config may be absent or None, hence the `or {}` guards.
    lora_settings = (config.get("model") or {}).get("lora") or {}
    lora_root = pathlib.Path(lora_settings.get("lora_dir") or "loras")

    return get_lora_list(lora_root.resolve())
|
||||
|
||||
# Currently loaded loras endpoint
|
||||
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_active_loras():
    """Return a ``LoraList`` describing the loras on the model container.

    Each entry in ``model_container.active_loras`` becomes one ``LoraCard``
    whose ``id`` is the name of the directory containing ``lora_path``.
    """
    cards = [
        LoraCard(
            id=pathlib.Path(entry.lora_path).parent.name,
            # NOTE(review): presumably converts the stored scaling back to
            # the user-facing value -- confirm against the lora loader.
            scaling=entry.lora_scaling * entry.lora_r / entry.lora_alpha,
        )
        for entry in model_container.active_loras
    ]

    return LoraList(data=cards)
|
||||
|
||||
# Load lora endpoint
|
||||
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def load_model(data: LoraLoadRequest):
    """Load the loras named in the request, replacing any already active.

    NOTE(review): this handler shares the name ``load_model`` with the
    model-load endpoint defined earlier in the file, so the module-level
    name ends up bound to this function. Consider renaming it (for
    example to ``load_lora``) once nothing relies on the current name.

    Args:
        data: Request body; ``data.loras`` must be non-empty. All fields
            are forwarded to ``model_container.load_loras`` as keyword
            arguments.

    Returns:
        A ``LoraLoadResponse`` whose ``success`` and ``failure`` lists come
        from the matching keys of the ``load_loras`` result (empty lists
        when a key is missing or falsy).

    Raises:
        HTTPException: 400 when no loras were requested, or when the
            configured lora directory does not exist.
    """
    if not data.loras:
        raise HTTPException(400, "List of loras to load is not found.")

    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}

    # Resolve to an absolute path so this endpoint hands load_loras the same
    # kind of path as the lora list endpoint and the startup loading code.
    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras").resolve()
    if not lora_dir.exists():
        raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")

    # Clean-up existing loras if present
    # NOTE(review): unload(True) presumably unloads loras only and keeps the
    # model loaded -- confirm against the model container.
    if model_container.active_loras:
        model_container.unload(True)

    result = model_container.load_loras(lora_dir, **data.dict())
    return LoraLoadResponse(
        success=result.get("success") or [],
        failure=result.get("failure") or [],
    )
|
||||
|
||||
# Unload lora endpoint
|
||||
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def unload_loras():
    """Call ``unload(True)`` on the current model container.

    NOTE(review): the ``True`` argument presumably restricts the unload to
    loras only -- confirm against the model container.
    """
    # model_container is only read here, never rebound, so no `global`
    # declaration is required.
    model_container.unload(True)
|
||||
|
||||
# Encode tokens endpoint
|
||||
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
@@ -308,6 +365,12 @@ if __name__ == "__main__":
|
||||
else:
|
||||
loading_bar.next()
|
||||
|
||||
# Load loras
|
||||
lora_config = model_config.get("lora") or {}
|
||||
if "loras" in lora_config:
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
||||
network_config = config.get("network") or {}
|
||||
uvicorn.run(
|
||||
app,
|
||||
|
||||
Reference in New Issue
Block a user