mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 02:01:24 +00:00
API: Add draft model support
A draft model can be loaded by including a child object called "draft" in the model load POST request. As with regular models, the draft model must be located inside the configured draft model directory in order to be loaded. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -12,6 +12,10 @@ class ModelList(BaseModel):
|
|||||||
object: str = "list"
|
object: str = "list"
|
||||||
data: List[ModelCard] = Field(default_factory=list)
|
data: List[ModelCard] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class DraftModelLoadRequest(BaseModel):
|
||||||
|
draft_model_name: str
|
||||||
|
draft_rope_alpha: float = 1.0
|
||||||
|
|
||||||
class ModelLoadRequest(BaseModel):
|
class ModelLoadRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
max_seq_len: Optional[int] = 4096
|
max_seq_len: Optional[int] = 4096
|
||||||
@@ -21,8 +25,10 @@ class ModelLoadRequest(BaseModel):
|
|||||||
rope_alpha: Optional[float] = 1.0
|
rope_alpha: Optional[float] = 1.0
|
||||||
no_flash_attention: Optional[bool] = False
|
no_flash_attention: Optional[bool] = False
|
||||||
low_mem: Optional[bool] = False
|
low_mem: Optional[bool] = False
|
||||||
|
draft: Optional[DraftModelLoadRequest] = None
|
||||||
|
|
||||||
class ModelLoadResponse(BaseModel):
|
class ModelLoadResponse(BaseModel):
|
||||||
|
model_type: str = "model"
|
||||||
module: int
|
module: int
|
||||||
modules: int
|
modules: int
|
||||||
status: str
|
status: str
|
||||||
|
|||||||
47
main.py
47
main.py
@@ -73,18 +73,28 @@ async def load_model(data: ModelLoadRequest):
|
|||||||
if model_container and model_container.model:
|
if model_container and model_container.model:
|
||||||
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
||||||
|
|
||||||
|
if not data.name:
|
||||||
|
raise HTTPException(400, "model_name not found.")
|
||||||
|
|
||||||
|
model_config = config.get("model") or {}
|
||||||
|
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||||
|
model_path = model_path / data.name
|
||||||
|
|
||||||
|
load_data = data.dict()
|
||||||
|
if data.draft and "draft" in model_config:
|
||||||
|
draft_config = model_config.get("draft") or {}
|
||||||
|
|
||||||
|
if not data.draft.draft_model_name:
|
||||||
|
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
|
||||||
|
|
||||||
|
load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
|
||||||
|
|
||||||
def generator():
|
def generator():
|
||||||
global model_container
|
global model_container
|
||||||
|
|
||||||
model_config = config.get("model") or {}
|
model_container = ModelContainer(model_path.resolve(), False, **load_data)
|
||||||
if "model_dir" in model_config:
|
model_type = "draft" if model_container.draft_enabled else "model"
|
||||||
model_path = pathlib.Path(model_config["model_dir"])
|
|
||||||
else:
|
|
||||||
model_path = pathlib.Path("models")
|
|
||||||
|
|
||||||
model_path = model_path / data.name
|
|
||||||
|
|
||||||
model_container = ModelContainer(model_path.resolve(), False, **data.dict())
|
|
||||||
load_status = model_container.load_gen(load_progress)
|
load_status = model_container.load_gen(load_progress)
|
||||||
for (module, modules) in load_status:
|
for (module, modules) in load_status:
|
||||||
if module == 0:
|
if module == 0:
|
||||||
@@ -92,10 +102,23 @@ async def load_model(data: ModelLoadRequest):
|
|||||||
elif module == modules:
|
elif module == modules:
|
||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
loading_bar.finish()
|
loading_bar.finish()
|
||||||
|
|
||||||
|
response = ModelLoadResponse(
|
||||||
|
model_type=model_type,
|
||||||
|
module=module,
|
||||||
|
modules=modules,
|
||||||
|
status="finished"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield response.json(ensure_ascii=False)
|
||||||
|
|
||||||
|
if model_container.draft_enabled:
|
||||||
|
model_type = "model"
|
||||||
else:
|
else:
|
||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
|
|
||||||
response = ModelLoadResponse(
|
response = ModelLoadResponse(
|
||||||
|
model_type=model_type,
|
||||||
module=module,
|
module=module,
|
||||||
modules=modules,
|
modules=modules,
|
||||||
status="processing"
|
status="processing"
|
||||||
@@ -103,14 +126,6 @@ async def load_model(data: ModelLoadRequest):
|
|||||||
|
|
||||||
yield response.json(ensure_ascii=False)
|
yield response.json(ensure_ascii=False)
|
||||||
|
|
||||||
response = ModelLoadResponse(
|
|
||||||
module=module,
|
|
||||||
modules=modules,
|
|
||||||
status="finished"
|
|
||||||
)
|
|
||||||
|
|
||||||
yield response.json(ensure_ascii=False)
|
|
||||||
|
|
||||||
return EventSourceResponse(generator())
|
return EventSourceResponse(generator())
|
||||||
|
|
||||||
# Unload model endpoint
|
# Unload model endpoint
|
||||||
|
|||||||
Reference in New Issue
Block a user