Progress: Switch to Rich

Rich is a more mature library for displaying progress bars, logging,
and console output. Switching to it should keep progress bars properly
aligned within the terminal.

Side note: "We're Rich!"

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-03-06 00:37:31 -05:00
Committed by: Brian Dashore
Parent: 39617adb65
Commit: fe0ff240e7
6 changed files with 55 additions and 24 deletions
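The hunks below replace progress.bar.IncrementalBar with a Rich Progress instance built by get_loading_progress_bar from common.utils. That helper is defined in one of the other changed files and isn't shown here; a minimal sketch of a Rich-based version (the column layout is an assumption, not the commit's actual code) might look like:

from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn

def get_loading_progress_bar() -> Progress:
    # Hypothetical sketch: the real definition lives in common/utils.py,
    # which is changed by this commit but not shown in this file's diff.
    return Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total} modules"),
        TimeElapsedColumn(),
    )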

main.py (51 lines changed)

@@ -15,7 +15,6 @@ from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from functools import partial
-from progress.bar import IncrementalBar
 import common.gen_logging as gen_logging
 from backends.exllamav2.model import ExllamaV2Container
@@ -46,6 +45,7 @@ from common.templating import (
 )
 from common.utils import (
     get_generator_error,
+    get_loading_progress_bar,
     get_sse_packet,
     handle_request_error,
     load_progress,
@@ -233,6 +233,9 @@ async def load_model(request: Request, data: ModelLoadRequest):
     load_status = MODEL_CONTAINER.load_gen(load_progress)

     try:
+        progress = get_loading_progress_bar()
+        progress.start()
+
         for module, modules in load_status:
             if await request.is_disconnected():
                 logger.error(
@@ -242,10 +245,23 @@ async def load_model(request: Request, data: ModelLoadRequest):
                 break

             if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()
+                loading_task = progress.add_task(
+                    "[cyan]Loading modules", total=modules
+                )
+            else:
+                progress.advance(loading_task)
+
+            response = ModelLoadResponse(
+                model_type=model_type,
+                module=module,
+                modules=modules,
+                status="processing",
+            )
+
+            yield get_sse_packet(response.model_dump_json())
+
+            if module == modules:
+                progress.stop()

                 response = ModelLoadResponse(
                     model_type=model_type,
@@ -259,17 +275,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                 # Switch to model progress if the draft model is loaded
                 if MODEL_CONTAINER.draft_config:
                     model_type = "model"
-            else:
-                loading_bar.next()
-
-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="processing",
-                )
-
-                yield get_sse_packet(response.model_dump_json())

     except CancelledError:
         logger.error(
             "Model load cancelled by user. "
@@ -277,6 +283,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
         )
     except Exception as exc:
         yield get_generator_error(str(exc))
+    finally:
+        progress.stop()

     # Determine whether to use or skip the queue
     if data.skip_queue:
@@ -749,14 +757,17 @@ def entrypoint(args: Optional[dict] = None):
             model_path.resolve(), False, **model_config
         )
         load_status = MODEL_CONTAINER.load_gen(load_progress)
+        progress = get_loading_progress_bar()
+        progress.start()
+
         for module, modules in load_status:
             if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()
+                loading_task = progress.add_task("[cyan]Loading modules", total=modules)
             else:
-                loading_bar.next()
+                progress.advance(loading_task)
+
+            if module == modules:
+                progress.stop()

         # Load loras after loading the model
         lora_config = get_lora_config()
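Both call sites follow the same Rich Progress lifecycle: start() the bar before iterating the load generator, add_task once the module count is known on the first iteration, advance once per module, and stop() when loading finishes (or, in the API path, in a finally block so the terminal is restored after a cancel or error). A minimal standalone sketch of that pattern, with an invented module count:

import time

from rich.progress import Progress

progress = Progress()
progress.start()
try:
    # The first iteration of the load generator learns the total;
    # the remaining iterations advance the bar one module at a time.
    task = progress.add_task("[cyan]Loading modules", total=8)
    for _ in range(8):
        time.sleep(0.1)  # stand-in for loading one module
        progress.advance(task)
finally:
    # Mirrors the diff's finally: progress.stop() so the cursor and
    # terminal state are restored even if loading raises.
    progress.stop()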