mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-13 09:20:09 +00:00
feat: pause total tqdm
This commit is contained in:
@@ -5,6 +5,7 @@ import platform
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -62,13 +63,24 @@ print(
|
||||
)
|
||||
|
||||
|
||||
class ChangeTorchLoad:
|
||||
def __enter__(self):
|
||||
self.orig = torch.load
|
||||
@contextmanager
|
||||
def change_torch_load():
|
||||
orig = torch.load
|
||||
try:
|
||||
torch.load = safe.unsafe_torch_load
|
||||
yield
|
||||
finally:
|
||||
torch.load = orig
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.load = self.orig
|
||||
|
||||
@contextmanager
|
||||
def pause_total_tqdm():
|
||||
orig = opts.data.get("multiple_tqdm", True)
|
||||
try:
|
||||
opts.data["multiple_tqdm"] = False
|
||||
yield
|
||||
finally:
|
||||
opts.data["multiple_tqdm"] = orig
|
||||
|
||||
|
||||
class AfterDetailerScript(scripts.Script):
|
||||
@@ -449,7 +461,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
ad_model = self.get_ad_model(args.ad_model)
|
||||
kwargs["device"] = self.ultralytics_device
|
||||
|
||||
with ChangeTorchLoad():
|
||||
with change_torch_load():
|
||||
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
|
||||
|
||||
masks = self.pred_preprocessing(pred, args)
|
||||
@@ -510,7 +522,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
for n, args in enumerate(arg_list):
|
||||
if args.ad_model == "None":
|
||||
continue
|
||||
with CNHijackRestore():
|
||||
with CNHijackRestore(), pause_total_tqdm():
|
||||
is_processed |= self._postprocess_image(p, pp, args, n=n)
|
||||
|
||||
if is_processed:
|
||||
|
||||
Reference in New Issue
Block a user