feat: pause total tqdm

This commit is contained in:
Bingsu
2023-05-26 00:49:45 +09:00
parent 29697be303
commit 14cb75d503

View File

@@ -5,6 +5,7 @@ import platform
import re
import sys
import traceback
from contextlib import contextmanager
from copy import copy, deepcopy
from pathlib import Path
from textwrap import dedent
@@ -62,13 +63,24 @@ print(
)
class ChangeTorchLoad:
def __enter__(self):
self.orig = torch.load
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
def __exit__(self, *args, **kwargs):
torch.load = self.orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
class AfterDetailerScript(scripts.Script):
@@ -449,7 +461,7 @@ class AfterDetailerScript(scripts.Script):
ad_model = self.get_ad_model(args.ad_model)
kwargs["device"] = self.ultralytics_device
with ChangeTorchLoad():
with change_torch_load():
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
masks = self.pred_preprocessing(pred, args)
@@ -510,7 +522,7 @@ class AfterDetailerScript(scripts.Script):
for n, args in enumerate(arg_list):
if args.ad_model == "None":
continue
with CNHijackRestore():
with CNHijackRestore(), pause_total_tqdm():
is_processed |= self._postprocess_image(p, pp, args, n=n)
if is_processed: