From 14cb75d5035dbbf806bb8ad9590e3556aff09f60 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Fri, 26 May 2023 00:49:45 +0900 Subject: [PATCH] feat: pause total tqdm --- scripts/!adetailer.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index d5dda67..a8bf08f 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -5,6 +5,7 @@ import platform import re import sys import traceback +from contextlib import contextmanager from copy import copy, deepcopy from pathlib import Path from textwrap import dedent @@ -62,13 +63,24 @@ print( ) -class ChangeTorchLoad: - def __enter__(self): - self.orig = torch.load +@contextmanager +def change_torch_load(): + orig = torch.load + try: torch.load = safe.unsafe_torch_load + yield + finally: + torch.load = orig - def __exit__(self, *args, **kwargs): - torch.load = self.orig + +@contextmanager +def pause_total_tqdm(): + orig = opts.data.get("multiple_tqdm", True) + try: + opts.data["multiple_tqdm"] = False + yield + finally: + opts.data["multiple_tqdm"] = orig class AfterDetailerScript(scripts.Script): @@ -449,7 +461,7 @@ class AfterDetailerScript(scripts.Script): ad_model = self.get_ad_model(args.ad_model) kwargs["device"] = self.ultralytics_device - with ChangeTorchLoad(): + with change_torch_load(): pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs) masks = self.pred_preprocessing(pred, args) @@ -510,7 +522,7 @@ class AfterDetailerScript(scripts.Script): for n, args in enumerate(arg_list): if args.ad_model == "None": continue - with CNHijackRestore(): + with CNHijackRestore(), pause_total_tqdm(): is_processed |= self._postprocess_image(p, pp, args, n=n) if is_processed: