mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-05-01 03:31:21 +00:00
refactor(script): reduce complexity
This commit is contained in:
@@ -5,6 +5,8 @@ from copy import copy
|
|||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from modules import safe
|
from modules import safe
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
@@ -57,3 +59,7 @@ def preserve_prompts(p: PT):
|
|||||||
|
|
||||||
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
|
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
|
||||||
return {k: v for k, v in extra_params.items() if not callable(v)}
|
return {k: v for k, v in extra_params.items() if not callable(v)}
|
||||||
|
|
||||||
|
|
||||||
|
class PPImage(Protocol):
|
||||||
|
image: Image.Image
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from rich import print
|
|||||||
import modules
|
import modules
|
||||||
from aaaaaa.conditional import create_binary_mask, schedulers
|
from aaaaaa.conditional import create_binary_mask, schedulers
|
||||||
from aaaaaa.helper import (
|
from aaaaaa.helper import (
|
||||||
|
PPImage,
|
||||||
change_torch_load,
|
change_torch_load,
|
||||||
copy_extra_params,
|
copy_extra_params,
|
||||||
pause_total_tqdm,
|
pause_total_tqdm,
|
||||||
@@ -744,6 +745,25 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
|
|
||||||
return optimal_resolution
|
return optimal_resolution
|
||||||
|
|
||||||
|
def fix_p2(
|
||||||
|
self, p, p2, pp: PPImage, args: ADetailerArgs, pred: PredictOutput, j: int
|
||||||
|
):
|
||||||
|
seed, subseed = self.get_seed(p)
|
||||||
|
p2.seed = self.get_each_tab_seed(seed, j)
|
||||||
|
p2.subseed = self.get_each_tab_seed(subseed, j)
|
||||||
|
p2.denoising_strength = self.get_dynamic_denoise_strength(
|
||||||
|
p2.denoising_strength, pred.bboxes[j], pp.image.size
|
||||||
|
)
|
||||||
|
|
||||||
|
p2.cached_c = [None, None]
|
||||||
|
p2.cached_uc = [None, None]
|
||||||
|
|
||||||
|
# Don't override user-defined dimensions.
|
||||||
|
if not args.ad_use_inpaint_width_height:
|
||||||
|
p2.width, p2.height = self.get_optimal_crop_image_size(
|
||||||
|
p2.width, p2.height, pred.bboxes[j]
|
||||||
|
)
|
||||||
|
|
||||||
@rich_traceback
|
@rich_traceback
|
||||||
def process(self, p, *args_):
|
def process(self, p, *args_):
|
||||||
if getattr(p, "_ad_disabled", False):
|
if getattr(p, "_ad_disabled", False):
|
||||||
@@ -779,7 +799,7 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
p.extra_generation_params.update(extra_params)
|
p.extra_generation_params.update(extra_params)
|
||||||
|
|
||||||
def _postprocess_image_inner(
|
def _postprocess_image_inner(
|
||||||
self, p, pp, args: ADetailerArgs, *, n: int = 0
|
self, p, pp: PPImage, args: ADetailerArgs, *, n: int = 0
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns
|
Returns
|
||||||
@@ -794,23 +814,22 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
i = get_i(p)
|
i = get_i(p)
|
||||||
|
|
||||||
i2i = self.get_i2i_p(p, args, pp.image)
|
i2i = self.get_i2i_p(p, args, pp.image)
|
||||||
seed, subseed = self.get_seed(p)
|
|
||||||
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
||||||
|
|
||||||
is_mediapipe = args.is_mediapipe()
|
is_mediapipe = args.is_mediapipe()
|
||||||
|
|
||||||
kwargs = {}
|
|
||||||
if is_mediapipe:
|
if is_mediapipe:
|
||||||
predictor = mediapipe_predict
|
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
|
||||||
ad_model = args.ad_model
|
|
||||||
else:
|
|
||||||
predictor = ultralytics_predict
|
|
||||||
ad_model = self.get_ad_model(args.ad_model)
|
|
||||||
kwargs["device"] = self.ultralytics_device
|
|
||||||
kwargs["classes"] = args.ad_model_classes
|
|
||||||
|
|
||||||
with change_torch_load():
|
else:
|
||||||
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
with change_torch_load():
|
||||||
|
pred = ultralytics_predict(
|
||||||
|
args.ad_model,
|
||||||
|
image=pp.image,
|
||||||
|
confidence=args.ad_confidence,
|
||||||
|
device=self.ultralytics_device,
|
||||||
|
classes=args.ad_model_classes,
|
||||||
|
)
|
||||||
|
|
||||||
if pred.preview is None:
|
if pred.preview is None:
|
||||||
print(
|
print(
|
||||||
@@ -844,20 +863,7 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
p2.seed = self.get_each_tab_seed(seed, j)
|
self.fix_p2(p, p2, pp, args, pred, j)
|
||||||
p2.subseed = self.get_each_tab_seed(subseed, j)
|
|
||||||
p2.denoising_strength = self.get_dynamic_denoise_strength(
|
|
||||||
p2.denoising_strength, pred.bboxes[j], pp.image.size
|
|
||||||
)
|
|
||||||
|
|
||||||
p2.cached_c = [None, None]
|
|
||||||
p2.cached_uc = [None, None]
|
|
||||||
|
|
||||||
# Don't override user-defined dimensions.
|
|
||||||
if not args.ad_use_inpaint_width_height:
|
|
||||||
p2.width, p2.height = self.get_optimal_crop_image_size(
|
|
||||||
p2.width, p2.height, pred.bboxes[j]
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
processed = process_images(p2)
|
processed = process_images(p2)
|
||||||
@@ -870,6 +876,11 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
|
|
||||||
self.compare_prompt(p.extra_generation_params, processed, n=n)
|
self.compare_prompt(p.extra_generation_params, processed, n=n)
|
||||||
p2 = copy(i2i)
|
p2 = copy(i2i)
|
||||||
|
|
||||||
|
if not processed.images:
|
||||||
|
processed = None
|
||||||
|
break
|
||||||
|
|
||||||
p2.init_images = [processed.images[0]]
|
p2.init_images = [processed.images[0]]
|
||||||
|
|
||||||
if processed is not None:
|
if processed is not None:
|
||||||
@@ -879,7 +890,7 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@rich_traceback
|
@rich_traceback
|
||||||
def postprocess_image(self, p, pp, *args_):
|
def postprocess_image(self, p, pp: PPImage, *args_):
|
||||||
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
|
if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user