diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index f835e46..f2391df 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -8,10 +8,9 @@ import torch import modules from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict from adetailer.common import dilate_erode, is_all_black, offset -from modules import devices, images, script_callbacks, scripts, shared +from modules import devices, images, safe, script_callbacks, scripts, shared from modules.paths import data_path, models_path from modules.processing import StableDiffusionProcessingImg2Img, process_images -from modules.safe import load, unsafe_torch_load from modules.shared import opts, state AFTER_DETAILER = "After Detailer" @@ -56,11 +55,21 @@ class ADetailerArgs: } -def with_gc(func): +def with_params_txt(func): + params_txt = Path(data_path, "params.txt") + original_params = "" + def wrapper(*args, **kwargs): - devices.torch_gc() + nonlocal original_params + + if not original_params and params_txt.exists(): + original_params = params_txt.read_text(encoding="utf-8") + result = func(*args, **kwargs) - devices.torch_gc() + + if original_params: + params_txt.write_text(original_params, encoding="utf-8") + return result return wrapper @@ -68,10 +77,11 @@ def with_gc(func): class ChangeTorchLoad: def __enter__(self): - torch.load = unsafe_torch_load + self.orig = torch.load + torch.load = safe.unsafe_torch_load def __exit__(self, *args, **kwargs): - torch.load = load + torch.load = self.orig def gr_show(visible=True): @@ -276,7 +286,7 @@ class AfterDetailerScript(scripts.Script): def get_args(*args): return ADetailerArgs(*args) - @with_gc + @with_params_txt def postprocess_image(self, p, pp, *args_): if getattr(p, "_disable_adetailer", False): return @@ -305,12 +315,6 @@ class AfterDetailerScript(scripts.Script): if sampler_name in ["PLMS", "UniPC"]: sampler_name = "Euler" - params_txt = Path(data_path, "params.txt") - original_params = "" - if params_txt.exists(): - with params_txt.open("r", encoding="utf-8") as f: - original_params = f.read() - i2i = StableDiffusionProcessingImg2Img( init_images=[pp.image], resize_mode=0, @@ -359,14 +363,12 @@ class AfterDetailerScript(scripts.Script): with ChangeTorchLoad(): pred = predictor(ad_model, pp.image, args.ad_conf) + devices.torch_gc() + if pred.masks is None: print("ADetailer: nothing detected with current settings") return - masks = pred.masks - steps = len(masks) - processed = None - if opts.data.get("ad_save_previews", False): images.save_image( pred.preview, @@ -379,13 +381,18 @@ class AfterDetailerScript(scripts.Script): suffix="-ad-preview", ) + masks = pred.masks + steps = len(masks) + processed = None + for j in range(steps): mask = masks[j] + mask = dilate_erode(mask, args.ad_dilate_erode) if is_all_black(mask): continue - mask = offset(mask, args.ad_x_offset, args.ad_y_offset) + mask = offset(mask, args.ad_x_offset, args.ad_y_offset) i2i.image_mask = mask processed = process_images(i2i) @@ -396,10 +403,6 @@ class AfterDetailerScript(scripts.Script): if processed is not None: pp.image = processed.images[0] - if original_params: - with params_txt.open("w", encoding="utf-8") as f: - f.write(original_params) - def on_ui_settings(): section = ("ADetailer", AFTER_DETAILER)