From 79d11f8b3fdef3a0f4538cb0dc915f28545b1e77 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Wed, 26 Apr 2023 23:37:03 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20change=20torch=20load,=20dilate=2036=20?= =?UTF-8?q?=E2=86=92=2032?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/!adetailer.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 78f8f95..bcec0a1 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -56,18 +56,24 @@ class ADetailerArgs: } -def adetailer_wrapper(func): +def with_gc(func): def wrapper(*args, **kwargs): devices.torch_gc() - torch.load = unsafe_torch_load result = func(*args, **kwargs) devices.torch_gc() - torch.load = load return result return wrapper +class ChangeTorchLoad: + def __enter__(self): + torch.load = unsafe_torch_load + + def __exit__(self, *args, **kwargs): + torch.load = load + + def gr_show(visible=True): return {"visible": visible, "__type__": "update"} @@ -124,7 +130,7 @@ class AfterDetailerScript(scripts.Script): minimum=-128, maximum=128, step=4, - value=36, + value=32, visible=True, ) @@ -270,7 +276,7 @@ class AfterDetailerScript(scripts.Script): def get_args(*args): return ADetailerArgs(*args) - @adetailer_wrapper + @with_gc def postprocess_image(self, p, pp, *args_): if getattr(p, "_disable_adetailer", False): return @@ -344,7 +350,9 @@ class AfterDetailerScript(scripts.Script): predictor = ultralytics_predict ad_model = model_mapping[args.ad_model] - pred = predictor(ad_model, pp.image, args.ad_conf) + with ChangeTorchLoad(): + pred = predictor(ad_model, pp.image, args.ad_conf) + if pred.masks is None: print("ADetailer: nothing detected with current settings") return