diff --git a/common/utils.py b/common/utils.py index 9f7b704..5f70757 100644 --- a/common/utils.py +++ b/common/utils.py @@ -2,6 +2,14 @@ import traceback from pydantic import BaseModel +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, +) from typing import Optional from common.logger import init_logger @@ -58,6 +66,18 @@ def get_sse_packet(json_data: str): return f"data: {json_data}\n\n" +def get_loading_progress_bar(): + """Gets a pre-made progress bar for loading tasks.""" + + return Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeRemainingColumn(), + ) + + def unwrap(wrapped, default=None): """Unwrap function for Optionals.""" if wrapped is None: diff --git a/main.py b/main.py index c866c82..0fec2cd 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,6 @@ from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from functools import partial -from progress.bar import IncrementalBar import common.gen_logging as gen_logging from backends.exllamav2.model import ExllamaV2Container @@ -46,6 +45,7 @@ from common.templating import ( ) from common.utils import ( get_generator_error, + get_loading_progress_bar, get_sse_packet, handle_request_error, load_progress, @@ -233,6 +233,9 @@ async def load_model(request: Request, data: ModelLoadRequest): load_status = MODEL_CONTAINER.load_gen(load_progress) try: + progress = get_loading_progress_bar() + progress.start() + for module, modules in load_status: if await request.is_disconnected(): logger.error( @@ -242,10 +245,23 @@ async def load_model(request: Request, data: ModelLoadRequest): break if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules) - elif module == modules: - loading_bar.next() - loading_bar.finish() + loading_task = progress.add_task( + "[cyan]Loading modules", total=modules + ) + else: + progress.advance(loading_task) + + response = ModelLoadResponse( + model_type=model_type, + module=module, + modules=modules, + status="processing", + ) + + yield get_sse_packet(response.model_dump_json()) + + if module == modules: + progress.stop() response = ModelLoadResponse( model_type=model_type, @@ -259,17 +275,7 @@ async def load_model(request: Request, data: ModelLoadRequest): # Switch to model progress if the draft model is loaded if MODEL_CONTAINER.draft_config: model_type = "model" - else: - loading_bar.next() - response = ModelLoadResponse( - model_type=model_type, - module=module, - modules=modules, - status="processing", - ) - - yield get_sse_packet(response.model_dump_json()) except CancelledError: logger.error( "Model load cancelled by user. " @@ -277,6 +283,8 @@ async def load_model(request: Request, data: ModelLoadRequest): ) except Exception as exc: yield get_generator_error(str(exc)) + finally: + progress.stop() # Determine whether to use or skip the queue if data.skip_queue: @@ -749,14 +757,17 @@ def entrypoint(args: Optional[dict] = None): model_path.resolve(), False, **model_config ) load_status = MODEL_CONTAINER.load_gen(load_progress) + + progress = get_loading_progress_bar() + progress.start() for module, modules in load_status: if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules) - elif module == modules: - loading_bar.next() - loading_bar.finish() + loading_task = progress.add_task("[cyan]Loading modules", total=modules) else: - loading_bar.next() + progress.advance(loading_task) + + if module == modules: + progress.stop() # Load loras after loading the model lora_config = get_lora_config() diff --git a/requirements-amd.txt b/requirements-amd.txt index 4de9ba4..ac8518a 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -10,7 +10,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1 fastapi pydantic >= 2.0.0 PyYAML -progress +rich uvicorn jinja2 >= 3.0.0 colorlog diff --git a/requirements-cu118.txt b/requirements-cu118.txt index b67700b..2c46d16 100644 --- a/requirements-cu118.txt +++ b/requirements-cu118.txt @@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1 fastapi pydantic >= 2.0.0 PyYAML -progress +rich uvicorn jinja2 >= 3.0.0 colorlog diff --git a/requirements-nowheel.txt b/requirements-nowheel.txt index c36272e..d3b6721 100644 --- a/requirements-nowheel.txt +++ b/requirements-nowheel.txt @@ -2,7 +2,7 @@ fastapi pydantic >= 2.0.0 PyYAML -progress +rich uvicorn jinja2 >= 3.0.0 colorlog diff --git a/requirements.txt b/requirements.txt index 0d1e47c..b613116 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1 fastapi pydantic >= 2.0.0 PyYAML -progress +rich uvicorn jinja2 >= 3.0.0 colorlog