refactor(script): reduce complexity

This commit is contained in:
Dowon
2024-08-24 13:51:11 +09:00
parent 18d8db995f
commit 79a74819cb
2 changed files with 44 additions and 27 deletions

View File

@@ -5,6 +5,8 @@ from copy import copy
from typing import TYPE_CHECKING, Any, Union
import torch
from PIL import Image
from typing_extensions import Protocol
from modules import safe
from modules.shared import opts
@@ -57,3 +59,7 @@ def preserve_prompts(p: PT):
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in extra_params.items() if not callable(v)}
class PPImage(Protocol):
image: Image.Image

View File

@@ -18,6 +18,7 @@ from rich import print
import modules
from aaaaaa.conditional import create_binary_mask, schedulers
from aaaaaa.helper import (
PPImage,
change_torch_load,
copy_extra_params,
pause_total_tqdm,
@@ -744,6 +745,25 @@ class AfterDetailerScript(scripts.Script):
return optimal_resolution
def fix_p2(
self, p, p2, pp: PPImage, args: ADetailerArgs, pred: PredictOutput, j: int
):
seed, subseed = self.get_seed(p)
p2.seed = self.get_each_tab_seed(seed, j)
p2.subseed = self.get_each_tab_seed(subseed, j)
p2.denoising_strength = self.get_dynamic_denoise_strength(
p2.denoising_strength, pred.bboxes[j], pp.image.size
)
p2.cached_c = [None, None]
p2.cached_uc = [None, None]
# Don't override user-defined dimensions.
if not args.ad_use_inpaint_width_height:
p2.width, p2.height = self.get_optimal_crop_image_size(
p2.width, p2.height, pred.bboxes[j]
)
@rich_traceback
def process(self, p, *args_):
if getattr(p, "_ad_disabled", False):
@@ -779,7 +799,7 @@ class AfterDetailerScript(scripts.Script):
p.extra_generation_params.update(extra_params)
def _postprocess_image_inner(
self, p, pp, args: ADetailerArgs, *, n: int = 0
self, p, pp: PPImage, args: ADetailerArgs, *, n: int = 0
) -> bool:
"""
Returns
@@ -794,23 +814,22 @@ class AfterDetailerScript(scripts.Script):
i = get_i(p)
i2i = self.get_i2i_p(p, args, pp.image)
seed, subseed = self.get_seed(p)
ad_prompts, ad_negatives = self.get_prompt(p, args)
is_mediapipe = args.is_mediapipe()
kwargs = {}
if is_mediapipe:
predictor = mediapipe_predict
ad_model = args.ad_model
else:
predictor = ultralytics_predict
ad_model = self.get_ad_model(args.ad_model)
kwargs["device"] = self.ultralytics_device
kwargs["classes"] = args.ad_model_classes
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
else:
with change_torch_load():
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
pred = ultralytics_predict(
args.ad_model,
image=pp.image,
confidence=args.ad_confidence,
device=self.ultralytics_device,
classes=args.ad_model_classes,
)
if pred.preview is None:
print(
@@ -844,20 +863,7 @@ class AfterDetailerScript(scripts.Script):
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
continue
p2.seed = self.get_each_tab_seed(seed, j)
p2.subseed = self.get_each_tab_seed(subseed, j)
p2.denoising_strength = self.get_dynamic_denoise_strength(
p2.denoising_strength, pred.bboxes[j], pp.image.size
)
p2.cached_c = [None, None]
p2.cached_uc = [None, None]
# Don't override user-defined dimensions.
if not args.ad_use_inpaint_width_height:
p2.width, p2.height = self.get_optimal_crop_image_size(
p2.width, p2.height, pred.bboxes[j]
)
self.fix_p2(p, p2, pp, args, pred, j)
try:
processed = process_images(p2)
@@ -870,6 +876,11 @@ class AfterDetailerScript(scripts.Script):
self.compare_prompt(p.extra_generation_params, processed, n=n)
p2 = copy(i2i)
if not processed.images:
processed = None
break
p2.init_images = [processed.images[0]]
if processed is not None:
@@ -879,7 +890,7 @@ class AfterDetailerScript(scripts.Script):
return False
@rich_traceback
def postprocess_image(self, p, pp, *args_):
def postprocess_image(self, p, pp: PPImage, *args_):
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
return