diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 78f8f95..bcec0a1 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -56,18 +56,24 @@ class ADetailerArgs: } -def adetailer_wrapper(func): +def with_gc(func): def wrapper(*args, **kwargs): devices.torch_gc() - torch.load = unsafe_torch_load result = func(*args, **kwargs) devices.torch_gc() - torch.load = load return result return wrapper +class ChangeTorchLoad: + def __enter__(self): + torch.load = unsafe_torch_load + + def __exit__(self, *args, **kwargs): + torch.load = load + + def gr_show(visible=True): return {"visible": visible, "__type__": "update"} @@ -124,7 +130,7 @@ class AfterDetailerScript(scripts.Script): minimum=-128, maximum=128, step=4, - value=36, + value=32, visible=True, ) @@ -270,7 +276,7 @@ class AfterDetailerScript(scripts.Script): def get_args(*args): return ADetailerArgs(*args) - @adetailer_wrapper + @with_gc def postprocess_image(self, p, pp, *args_): if getattr(p, "_disable_adetailer", False): return @@ -344,7 +350,9 @@ class AfterDetailerScript(scripts.Script): predictor = ultralytics_predict ad_model = model_mapping[args.ad_model] - pred = predictor(ad_model, pp.image, args.ad_conf) + with ChangeTorchLoad(): + pred = predictor(ad_model, pp.image, args.ad_conf) + if pred.masks is None: print("ADetailer: nothing detected with current settings") return