Progress: Fix bar with draft models

Show two bars and clarify which bar is which.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-03-08 01:48:06 -05:00
parent c9b4b7c509
commit 2295b12643

21
main.py
View File

@@ -244,7 +244,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if module == 0:
loading_task = progress.add_task(
"[cyan]Loading modules", total=modules
f"[cyan]Loading {model_type} modules", total=modules
)
else:
progress.advance(loading_task)
@@ -259,8 +259,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
yield get_sse_packet(response.model_dump_json())
if module == modules:
progress.stop()
response = ModelLoadResponse(
model_type=model_type,
module=module,
@@ -271,8 +269,10 @@ async def load_model(request: Request, data: ModelLoadRequest):
yield get_sse_packet(response.model_dump_json())
# Switch to model progress if the draft model is loaded
if MODEL_CONTAINER.draft_config:
if model_type == "draft":
model_type = "model"
else:
progress.stop()
except CancelledError:
logger.error(
@@ -760,14 +760,21 @@ def entrypoint(args: Optional[dict] = None):
progress = get_loading_progress_bar()
progress.start()
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
for module, modules in load_status:
if module == 0:
loading_task = progress.add_task("[cyan]Loading modules", total=modules)
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
)
else:
progress.advance(loading_task)
progress.advance(loading_task, 1)
if module == modules:
progress.stop()
if model_type == "draft":
model_type = "model"
else:
progress.stop()
# Load loras after loading the model
lora_config = get_lora_config()