diff --git a/main.py b/main.py index 0d606fc..bff7691 100644 --- a/main.py +++ b/main.py @@ -244,7 +244,7 @@ async def load_model(request: Request, data: ModelLoadRequest): if module == 0: loading_task = progress.add_task( - "[cyan]Loading modules", total=modules + f"[cyan]Loading {model_type} modules", total=modules ) else: progress.advance(loading_task) @@ -259,8 +259,6 @@ async def load_model(request: Request, data: ModelLoadRequest): yield get_sse_packet(response.model_dump_json()) if module == modules: - progress.stop() - response = ModelLoadResponse( model_type=model_type, module=module, @@ -271,8 +269,10 @@ async def load_model(request: Request, data: ModelLoadRequest): yield get_sse_packet(response.model_dump_json()) # Switch to model progress if the draft model is loaded - if MODEL_CONTAINER.draft_config: + if model_type == "draft": model_type = "model" + else: + progress.stop() except CancelledError: logger.error( @@ -760,14 +760,21 @@ def entrypoint(args: Optional[dict] = None): progress = get_loading_progress_bar() progress.start() + model_type = "draft" if MODEL_CONTAINER.draft_config else "model" + for module, modules in load_status: if module == 0: - loading_task = progress.add_task("[cyan]Loading modules", total=modules) + loading_task = progress.add_task( + f"[cyan]Loading {model_type} modules", total=modules + ) else: - progress.advance(loading_task) + progress.advance(loading_task, 1) if module == modules: - progress.stop() + if model_type == "draft": + model_type = "model" + else: + progress.stop() # Load loras after loading the model lora_config = get_lora_config()