fix: params.txt, misc

This commit is contained in:
Bingsu
2023-04-27 08:27:50 +09:00
parent f97bf1311d
commit 666e6b4409

View File

@@ -8,10 +8,9 @@ import torch
import modules
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
from adetailer.common import dilate_erode, is_all_black, offset
from modules import devices, images, script_callbacks, scripts, shared
from modules import devices, images, safe, script_callbacks, scripts, shared
from modules.paths import data_path, models_path
from modules.processing import StableDiffusionProcessingImg2Img, process_images
from modules.safe import load, unsafe_torch_load
from modules.shared import opts, state
AFTER_DETAILER = "After Detailer"
@@ -56,11 +55,21 @@ class ADetailerArgs:
}
def with_gc(func):
def with_params_txt(func):
params_txt = Path(data_path, "params.txt")
original_params = ""
def wrapper(*args, **kwargs):
devices.torch_gc()
nonlocal original_params
if not original_params and params_txt.exists():
original_params = params_txt.read_text(encoding="utf-8")
result = func(*args, **kwargs)
devices.torch_gc()
if original_params:
params_txt.write_text(original_params, encoding="utf-8")
return result
return wrapper
@@ -68,10 +77,11 @@ def with_gc(func):
class ChangeTorchLoad:
def __enter__(self):
torch.load = unsafe_torch_load
self.orig = torch.load
torch.load = safe.unsafe_torch_load
def __exit__(self, *args, **kwargs):
torch.load = load
torch.load = self.orig
def gr_show(visible=True):
@@ -276,7 +286,7 @@ class AfterDetailerScript(scripts.Script):
def get_args(*args):
return ADetailerArgs(*args)
@with_gc
@with_params_txt
def postprocess_image(self, p, pp, *args_):
if getattr(p, "_disable_adetailer", False):
return
@@ -305,12 +315,6 @@ class AfterDetailerScript(scripts.Script):
if sampler_name in ["PLMS", "UniPC"]:
sampler_name = "Euler"
params_txt = Path(data_path, "params.txt")
original_params = ""
if params_txt.exists():
with params_txt.open("r", encoding="utf-8") as f:
original_params = f.read()
i2i = StableDiffusionProcessingImg2Img(
init_images=[pp.image],
resize_mode=0,
@@ -359,14 +363,12 @@ class AfterDetailerScript(scripts.Script):
with ChangeTorchLoad():
pred = predictor(ad_model, pp.image, args.ad_conf)
devices.torch_gc()
if pred.masks is None:
print("ADetailer: nothing detected with current settings")
return
masks = pred.masks
steps = len(masks)
processed = None
if opts.data.get("ad_save_previews", False):
images.save_image(
pred.preview,
@@ -379,13 +381,18 @@ class AfterDetailerScript(scripts.Script):
suffix="-ad-preview",
)
masks = pred.masks
steps = len(masks)
processed = None
for j in range(steps):
mask = masks[j]
mask = dilate_erode(mask, args.ad_dilate_erode)
if is_all_black(mask):
continue
mask = offset(mask, args.ad_x_offset, args.ad_y_offset)
mask = offset(mask, args.ad_x_offset, args.ad_y_offset)
i2i.image_mask = mask
processed = process_images(i2i)
@@ -396,10 +403,6 @@ class AfterDetailerScript(scripts.Script):
if processed is not None:
pp.image = processed.images[0]
if original_params:
with params_txt.open("w", encoding="utf-8") as f:
f.write(original_params)
def on_ui_settings():
section = ("ADetailer", AFTER_DETAILER)