diff --git a/adetailer/mask.py b/adetailer/mask.py index d2f3680..c76d2a8 100644 --- a/adetailer/mask.py +++ b/adetailer/mask.py @@ -83,12 +83,22 @@ def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image: return ImageChops.offset(img, x, -y) -def is_all_black(img: Image.Image) -> bool: - arr = np.array(img) - return cv2.countNonZero(arr) == 0 +def is_all_black(img: Image.Image | np.ndarray) -> bool: + if isinstance(img, Image.Image): + img = np.array(img) + return cv2.countNonZero(img) == 0 -def bbox_area(bbox: list[float]): +def has_intersection(im1: Image.Image, im2: Image.Image) -> bool: + if im1.mode != "L" or im2.mode != "L": + msg = "Both images must be grayscale" + raise ValueError(msg) + arr1 = np.array(im1) + arr2 = np.array(im2) + return not is_all_black(cv2.bitwise_and(arr1, arr2)) + + +def bbox_area(bbox: list[float]) -> float: return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 8fa8352..8174509 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -31,6 +31,8 @@ from adetailer.common import PredictOutput from adetailer.mask import ( filter_by_ratio, filter_k_largest, + has_intersection, + is_all_black, mask_preprocess, sort_bboxes, ) @@ -637,16 +639,24 @@ class AfterDetailerScript(scripts.Script): @staticmethod def is_img2img_inpaint(p) -> bool: - return hasattr(p, "image_mask") and bool(p.image_mask) + return hasattr(p, "image_mask") and p.image_mask is not None + + @staticmethod + def inpaint_mask_filter( + img2img_mask: Image.Image, ad_mask: list[Image.Image] + ) -> list[Image.Image]: + return [mask for mask in ad_mask if has_intersection(img2img_mask, mask)] @rich_traceback def process(self, p, *args_): if getattr(p, "_ad_disabled", False): return - if self.is_img2img_inpaint(p): + if self.is_img2img_inpaint(p) and is_all_black(p.image_mask): p._ad_disabled = True - msg = "[-] ADetailer: img2img inpainting detected. adetailer disabled." + msg = ( + "[-] ADetailer: img2img inpainting with no mask -- adetailer disabled." + ) print(msg) return @@ -701,6 +711,8 @@ class AfterDetailerScript(scripts.Script): pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs) masks = self.pred_preprocessing(pred, args) + if self.is_img2img_inpaint(p): + masks = self.inpaint_mask_filter(p.image_mask, masks) shared.state.assign_current_image(pred.preview) if not masks: