Progress: Fix bar with draft models

Show two bars and clarify which bar is which.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-03-08 01:48:06 -05:00
parent c9b4b7c509
commit 2295b12643

21
main.py
View File

@@ -244,7 +244,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if module == 0: if module == 0:
loading_task = progress.add_task( loading_task = progress.add_task(
"[cyan]Loading modules", total=modules f"[cyan]Loading {model_type} modules", total=modules
) )
else: else:
progress.advance(loading_task) progress.advance(loading_task)
@@ -259,8 +259,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
yield get_sse_packet(response.model_dump_json()) yield get_sse_packet(response.model_dump_json())
if module == modules: if module == modules:
progress.stop()
response = ModelLoadResponse( response = ModelLoadResponse(
model_type=model_type, model_type=model_type,
module=module, module=module,
@@ -271,8 +269,10 @@ async def load_model(request: Request, data: ModelLoadRequest):
yield get_sse_packet(response.model_dump_json()) yield get_sse_packet(response.model_dump_json())
# Switch to model progress if the draft model is loaded # Switch to model progress if the draft model is loaded
if MODEL_CONTAINER.draft_config: if model_type == "draft":
model_type = "model" model_type = "model"
else:
progress.stop()
except CancelledError: except CancelledError:
logger.error( logger.error(
@@ -760,14 +760,21 @@ def entrypoint(args: Optional[dict] = None):
progress = get_loading_progress_bar() progress = get_loading_progress_bar()
progress.start() progress.start()
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
for module, modules in load_status: for module, modules in load_status:
if module == 0: if module == 0:
loading_task = progress.add_task("[cyan]Loading modules", total=modules) loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
)
else: else:
progress.advance(loading_task) progress.advance(loading_task, 1)
if module == modules: if module == modules:
progress.stop() if model_type == "draft":
model_type = "model"
else:
progress.stop()
# Load loras after loading the model # Load loras after loading the model
lora_config = get_lora_config() lora_config = get_lora_config()