mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-24 00:09:13 +00:00
fix: ensure pil image, mask has intersection
This commit is contained in:
@@ -16,7 +16,6 @@ import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from rich import print
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
import modules
|
||||
from adetailer import (
|
||||
@@ -27,7 +26,7 @@ from adetailer import (
|
||||
ultralytics_predict,
|
||||
)
|
||||
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
|
||||
from adetailer.common import PredictOutput
|
||||
from adetailer.common import PredictOutput, ensure_pil_image
|
||||
from adetailer.mask import (
|
||||
filter_by_ratio,
|
||||
filter_k_largest,
|
||||
@@ -582,14 +581,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
masks = self.inpaint_mask_filter(p.image_mask, masks)
|
||||
return masks
|
||||
|
||||
@staticmethod
|
||||
def ensure_rgb_image(image: Any):
|
||||
if not isinstance(image, Image.Image):
|
||||
image = to_pil_image(image)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def i2i_prompts_replace(
|
||||
i2i, prompts: list[str], negative_prompts: list[str], j: int
|
||||
@@ -737,7 +728,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
p2 = copy(i2i)
|
||||
for j in range(steps):
|
||||
p2.image_mask = masks[j]
|
||||
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
|
||||
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
|
||||
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
||||
|
||||
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||
@@ -771,7 +762,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
return
|
||||
|
||||
pp.image = self.get_i2i_init_image(p, pp)
|
||||
pp.image = self.ensure_rgb_image(pp.image)
|
||||
pp.image = ensure_pil_image(pp.image, "RGB")
|
||||
init_image = copy(pp.image)
|
||||
arg_list = self.get_args(p, *args_)
|
||||
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user