mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-13 17:30:01 +00:00
fix: ensure pil image, mask has intersection
This commit is contained in:
@@ -3,11 +3,12 @@ from __future__ import annotations
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
from rich import print
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
REPO_ID = "Bingsu/adetailer"
|
||||
_download_failed = False
|
||||
@@ -133,3 +134,11 @@ def create_bbox_from_mask(
|
||||
if bbox is not None:
|
||||
bboxes.append(list(bbox))
|
||||
return bboxes
|
||||
|
||||
|
||||
def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
|
||||
if not isinstance(image, Image.Image):
|
||||
image = to_pil_image(image)
|
||||
if image.mode != mode:
|
||||
image = image.convert(mode)
|
||||
return image
|
||||
|
||||
@@ -3,13 +3,14 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from functools import partial, reduce
|
||||
from math import dist
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from adetailer.args import MASK_MERGE_INVERT
|
||||
from adetailer.common import PredictOutput
|
||||
from adetailer.common import PredictOutput, ensure_pil_image
|
||||
|
||||
|
||||
class SortBy(IntEnum):
|
||||
@@ -89,12 +90,9 @@ def is_all_black(img: Image.Image | np.ndarray) -> bool:
|
||||
return cv2.countNonZero(img) == 0
|
||||
|
||||
|
||||
def has_intersection(im1: Image.Image, im2: Image.Image) -> bool:
|
||||
if im1.mode != "L" or im2.mode != "L":
|
||||
msg = "Both images must be grayscale"
|
||||
raise ValueError(msg)
|
||||
arr1 = np.array(im1)
|
||||
arr2 = np.array(im2)
|
||||
def has_intersection(im1: Any, im2: Any) -> bool:
|
||||
arr1 = np.array(ensure_pil_image(im1, "L"))
|
||||
arr2 = np.array(ensure_pil_image(im2, "L"))
|
||||
return not is_all_black(cv2.bitwise_and(arr1, arr2))
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from rich import print
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
import modules
|
||||
from adetailer import (
|
||||
@@ -27,7 +26,7 @@ from adetailer import (
|
||||
ultralytics_predict,
|
||||
)
|
||||
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
|
||||
from adetailer.common import PredictOutput
|
||||
from adetailer.common import PredictOutput, ensure_pil_image
|
||||
from adetailer.mask import (
|
||||
filter_by_ratio,
|
||||
filter_k_largest,
|
||||
@@ -582,14 +581,6 @@ class AfterDetailerScript(scripts.Script):
|
||||
masks = self.inpaint_mask_filter(p.image_mask, masks)
|
||||
return masks
|
||||
|
||||
@staticmethod
|
||||
def ensure_rgb_image(image: Any):
|
||||
if not isinstance(image, Image.Image):
|
||||
image = to_pil_image(image)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def i2i_prompts_replace(
|
||||
i2i, prompts: list[str], negative_prompts: list[str], j: int
|
||||
@@ -737,7 +728,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
p2 = copy(i2i)
|
||||
for j in range(steps):
|
||||
p2.image_mask = masks[j]
|
||||
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
|
||||
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
|
||||
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
||||
|
||||
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||
@@ -771,7 +762,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
return
|
||||
|
||||
pp.image = self.get_i2i_init_image(p, pp)
|
||||
pp.image = self.ensure_rgb_image(pp.image)
|
||||
pp.image = ensure_pil_image(pp.image, "RGB")
|
||||
init_image = copy(pp.image)
|
||||
arg_list = self.get_args(p, *args_)
|
||||
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user