mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 19:29:54 +00:00
refactor(script): reduce complexity
This commit is contained in:
@@ -5,6 +5,8 @@ from copy import copy
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from modules import safe
|
||||
from modules.shared import opts
|
||||
@@ -57,3 +59,7 @@ def preserve_prompts(p: PT):
|
||||
|
||||
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
|
||||
return {k: v for k, v in extra_params.items() if not callable(v)}
|
||||
|
||||
|
||||
class PPImage(Protocol):
|
||||
image: Image.Image
|
||||
|
||||
@@ -18,6 +18,7 @@ from rich import print
|
||||
import modules
|
||||
from aaaaaa.conditional import create_binary_mask, schedulers
|
||||
from aaaaaa.helper import (
|
||||
PPImage,
|
||||
change_torch_load,
|
||||
copy_extra_params,
|
||||
pause_total_tqdm,
|
||||
@@ -744,6 +745,25 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
return optimal_resolution
|
||||
|
||||
def fix_p2(
|
||||
self, p, p2, pp: PPImage, args: ADetailerArgs, pred: PredictOutput, j: int
|
||||
):
|
||||
seed, subseed = self.get_seed(p)
|
||||
p2.seed = self.get_each_tab_seed(seed, j)
|
||||
p2.subseed = self.get_each_tab_seed(subseed, j)
|
||||
p2.denoising_strength = self.get_dynamic_denoise_strength(
|
||||
p2.denoising_strength, pred.bboxes[j], pp.image.size
|
||||
)
|
||||
|
||||
p2.cached_c = [None, None]
|
||||
p2.cached_uc = [None, None]
|
||||
|
||||
# Don't override user-defined dimensions.
|
||||
if not args.ad_use_inpaint_width_height:
|
||||
p2.width, p2.height = self.get_optimal_crop_image_size(
|
||||
p2.width, p2.height, pred.bboxes[j]
|
||||
)
|
||||
|
||||
@rich_traceback
|
||||
def process(self, p, *args_):
|
||||
if getattr(p, "_ad_disabled", False):
|
||||
@@ -779,7 +799,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
p.extra_generation_params.update(extra_params)
|
||||
|
||||
def _postprocess_image_inner(
|
||||
self, p, pp, args: ADetailerArgs, *, n: int = 0
|
||||
self, p, pp: PPImage, args: ADetailerArgs, *, n: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Returns
|
||||
@@ -794,23 +814,22 @@ class AfterDetailerScript(scripts.Script):
|
||||
i = get_i(p)
|
||||
|
||||
i2i = self.get_i2i_p(p, args, pp.image)
|
||||
seed, subseed = self.get_seed(p)
|
||||
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
||||
|
||||
is_mediapipe = args.is_mediapipe()
|
||||
|
||||
kwargs = {}
|
||||
if is_mediapipe:
|
||||
predictor = mediapipe_predict
|
||||
ad_model = args.ad_model
|
||||
else:
|
||||
predictor = ultralytics_predict
|
||||
ad_model = self.get_ad_model(args.ad_model)
|
||||
kwargs["device"] = self.ultralytics_device
|
||||
kwargs["classes"] = args.ad_model_classes
|
||||
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
|
||||
|
||||
else:
|
||||
with change_torch_load():
|
||||
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
||||
pred = ultralytics_predict(
|
||||
args.ad_model,
|
||||
image=pp.image,
|
||||
confidence=args.ad_confidence,
|
||||
device=self.ultralytics_device,
|
||||
classes=args.ad_model_classes,
|
||||
)
|
||||
|
||||
if pred.preview is None:
|
||||
print(
|
||||
@@ -844,20 +863,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||
continue
|
||||
|
||||
p2.seed = self.get_each_tab_seed(seed, j)
|
||||
p2.subseed = self.get_each_tab_seed(subseed, j)
|
||||
p2.denoising_strength = self.get_dynamic_denoise_strength(
|
||||
p2.denoising_strength, pred.bboxes[j], pp.image.size
|
||||
)
|
||||
|
||||
p2.cached_c = [None, None]
|
||||
p2.cached_uc = [None, None]
|
||||
|
||||
# Don't override user-defined dimensions.
|
||||
if not args.ad_use_inpaint_width_height:
|
||||
p2.width, p2.height = self.get_optimal_crop_image_size(
|
||||
p2.width, p2.height, pred.bboxes[j]
|
||||
)
|
||||
self.fix_p2(p, p2, pp, args, pred, j)
|
||||
|
||||
try:
|
||||
processed = process_images(p2)
|
||||
@@ -870,6 +876,11 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
self.compare_prompt(p.extra_generation_params, processed, n=n)
|
||||
p2 = copy(i2i)
|
||||
|
||||
if not processed.images:
|
||||
processed = None
|
||||
break
|
||||
|
||||
p2.init_images = [processed.images[0]]
|
||||
|
||||
if processed is not None:
|
||||
@@ -879,7 +890,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
return False
|
||||
|
||||
@rich_traceback
|
||||
def postprocess_image(self, p, pp, *args_):
|
||||
def postprocess_image(self, p, pp: PPImage, *args_):
|
||||
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user