fix: ensure pil image, mask has intersection

This commit is contained in:
Dowon
2024-03-16 14:50:00 +09:00
parent a9feea2662
commit f6a7b95585
3 changed files with 18 additions and 20 deletions

View File

@@ -16,7 +16,6 @@ import gradio as gr
import torch
from PIL import Image
from rich import print
from torchvision.transforms.functional import to_pil_image
import modules
from adetailer import (
@@ -27,7 +26,7 @@ from adetailer import (
ultralytics_predict,
)
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
from adetailer.common import PredictOutput
from adetailer.common import PredictOutput, ensure_pil_image
from adetailer.mask import (
filter_by_ratio,
filter_k_largest,
@@ -582,14 +581,6 @@ class AfterDetailerScript(scripts.Script):
masks = self.inpaint_mask_filter(p.image_mask, masks)
return masks
@staticmethod
def ensure_rgb_image(image: Any):
if not isinstance(image, Image.Image):
image = to_pil_image(image)
if image.mode != "RGB":
image = image.convert("RGB")
return image
@staticmethod
def i2i_prompts_replace(
i2i, prompts: list[str], negative_prompts: list[str], j: int
@@ -737,7 +728,7 @@ class AfterDetailerScript(scripts.Script):
p2 = copy(i2i)
for j in range(steps):
p2.image_mask = masks[j]
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
@@ -771,7 +762,7 @@ class AfterDetailerScript(scripts.Script):
return
pp.image = self.get_i2i_init_image(p, pp)
pp.image = self.ensure_rgb_image(pp.image)
pp.image = ensure_pil_image(pp.image, "RGB")
init_image = copy(pp.image)
arg_list = self.get_args(p, *args_)
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")