fix: ensure pil image, mask has intersection

This commit is contained in:
Dowon
2024-03-16 14:50:00 +09:00
parent a9feea2662
commit f6a7b95585
3 changed files with 18 additions and 20 deletions

View File

@@ -3,11 +3,12 @@ from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from rich import print
from torchvision.transforms.functional import to_pil_image
REPO_ID = "Bingsu/adetailer"
_download_failed = False
@@ -133,3 +134,11 @@ def create_bbox_from_mask(
if bbox is not None:
bboxes.append(list(bbox))
return bboxes
def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
if not isinstance(image, Image.Image):
image = to_pil_image(image)
if image.mode != mode:
image = image.convert(mode)
return image

View File

@@ -3,13 +3,14 @@ from __future__ import annotations
from enum import IntEnum
from functools import partial, reduce
from math import dist
from typing import Any
import cv2
import numpy as np
from PIL import Image, ImageChops
from adetailer.args import MASK_MERGE_INVERT
from adetailer.common import PredictOutput
from adetailer.common import PredictOutput, ensure_pil_image
class SortBy(IntEnum):
@@ -89,12 +90,9 @@ def is_all_black(img: Image.Image | np.ndarray) -> bool:
return cv2.countNonZero(img) == 0
def has_intersection(im1: Image.Image, im2: Image.Image) -> bool:
if im1.mode != "L" or im2.mode != "L":
msg = "Both images must be grayscale"
raise ValueError(msg)
arr1 = np.array(im1)
arr2 = np.array(im2)
def has_intersection(im1: Any, im2: Any) -> bool:
arr1 = np.array(ensure_pil_image(im1, "L"))
arr2 = np.array(ensure_pil_image(im2, "L"))
return not is_all_black(cv2.bitwise_and(arr1, arr2))

View File

@@ -16,7 +16,6 @@ import gradio as gr
import torch
from PIL import Image
from rich import print
from torchvision.transforms.functional import to_pil_image
import modules
from adetailer import (
@@ -27,7 +26,7 @@ from adetailer import (
ultralytics_predict,
)
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
from adetailer.common import PredictOutput
from adetailer.common import PredictOutput, ensure_pil_image
from adetailer.mask import (
filter_by_ratio,
filter_k_largest,
@@ -582,14 +581,6 @@ class AfterDetailerScript(scripts.Script):
masks = self.inpaint_mask_filter(p.image_mask, masks)
return masks
@staticmethod
def ensure_rgb_image(image: Any):
if not isinstance(image, Image.Image):
image = to_pil_image(image)
if image.mode != "RGB":
image = image.convert("RGB")
return image
@staticmethod
def i2i_prompts_replace(
i2i, prompts: list[str], negative_prompts: list[str], j: int
@@ -737,7 +728,7 @@ class AfterDetailerScript(scripts.Script):
p2 = copy(i2i)
for j in range(steps):
p2.image_mask = masks[j]
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
@@ -771,7 +762,7 @@ class AfterDetailerScript(scripts.Script):
return
pp.image = self.get_i2i_init_image(p, pp)
pp.image = self.ensure_rgb_image(pp.image)
pp.image = ensure_pil_image(pp.image, "RGB")
init_image = copy(pp.image)
arg_list = self.get_args(p, *args_)
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")