refactor(script): reduce complexity

This commit is contained in:
Dowon
2024-08-24 13:51:11 +09:00
parent 18d8db995f
commit 79a74819cb
2 changed files with 44 additions and 27 deletions

View File

@@ -5,6 +5,8 @@ from copy import copy
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any, Union
import torch import torch
from PIL import Image
from typing_extensions import Protocol
from modules import safe from modules import safe
from modules.shared import opts from modules.shared import opts
@@ -57,3 +59,7 @@ def preserve_prompts(p: PT):
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]: def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in extra_params.items() if not callable(v)} return {k: v for k, v in extra_params.items() if not callable(v)}
class PPImage(Protocol):
image: Image.Image

View File

@@ -18,6 +18,7 @@ from rich import print
import modules import modules
from aaaaaa.conditional import create_binary_mask, schedulers from aaaaaa.conditional import create_binary_mask, schedulers
from aaaaaa.helper import ( from aaaaaa.helper import (
PPImage,
change_torch_load, change_torch_load,
copy_extra_params, copy_extra_params,
pause_total_tqdm, pause_total_tqdm,
@@ -744,6 +745,25 @@ class AfterDetailerScript(scripts.Script):
return optimal_resolution return optimal_resolution
def fix_p2(
self, p, p2, pp: PPImage, args: ADetailerArgs, pred: PredictOutput, j: int
):
seed, subseed = self.get_seed(p)
p2.seed = self.get_each_tab_seed(seed, j)
p2.subseed = self.get_each_tab_seed(subseed, j)
p2.denoising_strength = self.get_dynamic_denoise_strength(
p2.denoising_strength, pred.bboxes[j], pp.image.size
)
p2.cached_c = [None, None]
p2.cached_uc = [None, None]
# Don't override user-defined dimensions.
if not args.ad_use_inpaint_width_height:
p2.width, p2.height = self.get_optimal_crop_image_size(
p2.width, p2.height, pred.bboxes[j]
)
@rich_traceback @rich_traceback
def process(self, p, *args_): def process(self, p, *args_):
if getattr(p, "_ad_disabled", False): if getattr(p, "_ad_disabled", False):
@@ -779,7 +799,7 @@ class AfterDetailerScript(scripts.Script):
p.extra_generation_params.update(extra_params) p.extra_generation_params.update(extra_params)
def _postprocess_image_inner( def _postprocess_image_inner(
self, p, pp, args: ADetailerArgs, *, n: int = 0 self, p, pp: PPImage, args: ADetailerArgs, *, n: int = 0
) -> bool: ) -> bool:
""" """
Returns Returns
@@ -794,23 +814,22 @@ class AfterDetailerScript(scripts.Script):
i = get_i(p) i = get_i(p)
i2i = self.get_i2i_p(p, args, pp.image) i2i = self.get_i2i_p(p, args, pp.image)
seed, subseed = self.get_seed(p)
ad_prompts, ad_negatives = self.get_prompt(p, args) ad_prompts, ad_negatives = self.get_prompt(p, args)
is_mediapipe = args.is_mediapipe() is_mediapipe = args.is_mediapipe()
kwargs = {}
if is_mediapipe: if is_mediapipe:
predictor = mediapipe_predict pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
ad_model = args.ad_model
else:
predictor = ultralytics_predict
ad_model = self.get_ad_model(args.ad_model)
kwargs["device"] = self.ultralytics_device
kwargs["classes"] = args.ad_model_classes
with change_torch_load(): else:
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs) with change_torch_load():
pred = ultralytics_predict(
args.ad_model,
image=pp.image,
confidence=args.ad_confidence,
device=self.ultralytics_device,
classes=args.ad_model_classes,
)
if pred.preview is None: if pred.preview is None:
print( print(
@@ -844,20 +863,7 @@ class AfterDetailerScript(scripts.Script):
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt): if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
continue continue
p2.seed = self.get_each_tab_seed(seed, j) self.fix_p2(p, p2, pp, args, pred, j)
p2.subseed = self.get_each_tab_seed(subseed, j)
p2.denoising_strength = self.get_dynamic_denoise_strength(
p2.denoising_strength, pred.bboxes[j], pp.image.size
)
p2.cached_c = [None, None]
p2.cached_uc = [None, None]
# Don't override user-defined dimensions.
if not args.ad_use_inpaint_width_height:
p2.width, p2.height = self.get_optimal_crop_image_size(
p2.width, p2.height, pred.bboxes[j]
)
try: try:
processed = process_images(p2) processed = process_images(p2)
@@ -870,6 +876,11 @@ class AfterDetailerScript(scripts.Script):
self.compare_prompt(p.extra_generation_params, processed, n=n) self.compare_prompt(p.extra_generation_params, processed, n=n)
p2 = copy(i2i) p2 = copy(i2i)
if not processed.images:
processed = None
break
p2.init_images = [processed.images[0]] p2.init_images = [processed.images[0]]
if processed is not None: if processed is not None:
@@ -879,7 +890,7 @@ class AfterDetailerScript(scripts.Script):
return False return False
@rich_traceback @rich_traceback
def postprocess_image(self, p, pp, *args_): def postprocess_image(self, p, pp: PPImage, *args_):
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_): if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
return return