mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-14 01:40:05 +00:00
fix: params.txt, misc
This commit is contained in:
@@ -8,10 +8,9 @@ import torch
|
||||
import modules
|
||||
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
|
||||
from adetailer.common import dilate_erode, is_all_black, offset
|
||||
from modules import devices, images, script_callbacks, scripts, shared
|
||||
from modules import devices, images, safe, script_callbacks, scripts, shared
|
||||
from modules.paths import data_path, models_path
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.safe import load, unsafe_torch_load
|
||||
from modules.shared import opts, state
|
||||
|
||||
AFTER_DETAILER = "After Detailer"
|
||||
@@ -56,11 +55,21 @@ class ADetailerArgs:
|
||||
}
|
||||
|
||||
|
||||
def with_gc(func):
|
||||
def with_params_txt(func):
|
||||
params_txt = Path(data_path, "params.txt")
|
||||
original_params = ""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
devices.torch_gc()
|
||||
nonlocal original_params
|
||||
|
||||
if not original_params and params_txt.exists():
|
||||
original_params = params_txt.read_text(encoding="utf-8")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
devices.torch_gc()
|
||||
|
||||
if original_params:
|
||||
params_txt.write_text(original_params, encoding="utf-8")
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -68,10 +77,11 @@ def with_gc(func):
|
||||
|
||||
class ChangeTorchLoad:
|
||||
def __enter__(self):
|
||||
torch.load = unsafe_torch_load
|
||||
self.orig = torch.load
|
||||
torch.load = safe.unsafe_torch_load
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.load = load
|
||||
torch.load = self.orig
|
||||
|
||||
|
||||
def gr_show(visible=True):
|
||||
@@ -276,7 +286,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
def get_args(*args):
|
||||
return ADetailerArgs(*args)
|
||||
|
||||
@with_gc
|
||||
@with_params_txt
|
||||
def postprocess_image(self, p, pp, *args_):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
@@ -305,12 +315,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
if sampler_name in ["PLMS", "UniPC"]:
|
||||
sampler_name = "Euler"
|
||||
|
||||
params_txt = Path(data_path, "params.txt")
|
||||
original_params = ""
|
||||
if params_txt.exists():
|
||||
with params_txt.open("r", encoding="utf-8") as f:
|
||||
original_params = f.read()
|
||||
|
||||
i2i = StableDiffusionProcessingImg2Img(
|
||||
init_images=[pp.image],
|
||||
resize_mode=0,
|
||||
@@ -359,14 +363,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
with ChangeTorchLoad():
|
||||
pred = predictor(ad_model, pp.image, args.ad_conf)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if pred.masks is None:
|
||||
print("ADetailer: nothing detected with current settings")
|
||||
return
|
||||
|
||||
masks = pred.masks
|
||||
steps = len(masks)
|
||||
processed = None
|
||||
|
||||
if opts.data.get("ad_save_previews", False):
|
||||
images.save_image(
|
||||
pred.preview,
|
||||
@@ -379,13 +381,18 @@ class AfterDetailerScript(scripts.Script):
|
||||
suffix="-ad-preview",
|
||||
)
|
||||
|
||||
masks = pred.masks
|
||||
steps = len(masks)
|
||||
processed = None
|
||||
|
||||
for j in range(steps):
|
||||
mask = masks[j]
|
||||
|
||||
mask = dilate_erode(mask, args.ad_dilate_erode)
|
||||
if is_all_black(mask):
|
||||
continue
|
||||
mask = offset(mask, args.ad_x_offset, args.ad_y_offset)
|
||||
|
||||
mask = offset(mask, args.ad_x_offset, args.ad_y_offset)
|
||||
i2i.image_mask = mask
|
||||
|
||||
processed = process_images(i2i)
|
||||
@@ -396,10 +403,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
if processed is not None:
|
||||
pp.image = processed.images[0]
|
||||
|
||||
if original_params:
|
||||
with params_txt.open("w", encoding="utf-8") as f:
|
||||
f.write(original_params)
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
section = ("ADetailer", AFTER_DETAILER)
|
||||
|
||||
Reference in New Issue
Block a user