diff --git a/adetailer/common.py b/adetailer/common.py
index cf3ad3f..9e1d1b9 100644
--- a/adetailer/common.py
+++ b/adetailer/common.py
@@ -3,11 +3,12 @@ from __future__ import annotations
 from collections import OrderedDict
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
 from huggingface_hub import hf_hub_download
 from PIL import Image, ImageDraw
 from rich import print
+from torchvision.transforms.functional import to_pil_image
 
 REPO_ID = "Bingsu/adetailer"
 _download_failed = False
@@ -133,3 +134,11 @@ def create_bbox_from_mask(
         if bbox is not None:
             bboxes.append(list(bbox))
     return bboxes
+
+
+def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
+    if not isinstance(image, Image.Image):
+        image = to_pil_image(image)
+    if image.mode != mode:
+        image = image.convert(mode)
+    return image
diff --git a/adetailer/mask.py b/adetailer/mask.py
index c76d2a8..5514828 100644
--- a/adetailer/mask.py
+++ b/adetailer/mask.py
@@ -3,13 +3,14 @@ from __future__ import annotations
 from enum import IntEnum
 from functools import partial, reduce
 from math import dist
+from typing import Any
 
 import cv2
 import numpy as np
 from PIL import Image, ImageChops
 
 from adetailer.args import MASK_MERGE_INVERT
-from adetailer.common import PredictOutput
+from adetailer.common import PredictOutput, ensure_pil_image
 
 
 class SortBy(IntEnum):
@@ -89,12 +90,9 @@ def is_all_black(img: Image.Image | np.ndarray) -> bool:
     return cv2.countNonZero(img) == 0
 
 
-def has_intersection(im1: Image.Image, im2: Image.Image) -> bool:
-    if im1.mode != "L" or im2.mode != "L":
-        msg = "Both images must be grayscale"
-        raise ValueError(msg)
-    arr1 = np.array(im1)
-    arr2 = np.array(im2)
+def has_intersection(im1: Any, im2: Any) -> bool:
+    arr1 = np.array(ensure_pil_image(im1, "L"))
+    arr2 = np.array(ensure_pil_image(im2, "L"))
     return not is_all_black(cv2.bitwise_and(arr1, arr2))
 
 
diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py
index 534eae0..18f6382 100644
--- a/scripts/!adetailer.py
+++ b/scripts/!adetailer.py
@@ -16,7 +16,6 @@ import gradio as gr
 import torch
 from PIL import Image
 from rich import print
-from torchvision.transforms.functional import to_pil_image
 
 import modules
 from adetailer import (
@@ -27,7 +26,7 @@ from adetailer import (
     ultralytics_predict,
 )
 from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
-from adetailer.common import PredictOutput
+from adetailer.common import PredictOutput, ensure_pil_image
 from adetailer.mask import (
     filter_by_ratio,
     filter_k_largest,
@@ -582,14 +581,6 @@ class AfterDetailerScript(scripts.Script):
         masks = self.inpaint_mask_filter(p.image_mask, masks)
         return masks
 
-    @staticmethod
-    def ensure_rgb_image(image: Any):
-        if not isinstance(image, Image.Image):
-            image = to_pil_image(image)
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-        return image
-
     @staticmethod
     def i2i_prompts_replace(
         i2i, prompts: list[str], negative_prompts: list[str], j: int
@@ -737,7 +728,7 @@ class AfterDetailerScript(scripts.Script):
         p2 = copy(i2i)
         for j in range(steps):
             p2.image_mask = masks[j]
-            p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
+            p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
             self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
 
             if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
@@ -771,7 +762,7 @@ class AfterDetailerScript(scripts.Script):
             return
 
         pp.image = self.get_i2i_init_image(p, pp)
-        pp.image = self.ensure_rgb_image(pp.image)
+        pp.image = ensure_pil_image(pp.image, "RGB")
         init_image = copy(pp.image)
         arg_list = self.get_args(p, *args_)
         params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")