From a1b498ff3986fd9feed8e780b17383a8d7d2fb95 Mon Sep 17 00:00:00 2001 From: JudahGazit Date: Tue, 8 Aug 2023 14:15:37 +0300 Subject: [PATCH] feat: Masking only the top k largest predictions (#264) --- adetailer/args.py | 2 ++ adetailer/mask.py | 10 ++++++++++ adetailer/ui.py | 12 +++++++++++- scripts/!adetailer.py | 3 ++- 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/adetailer/args.py b/adetailer/args.py index f12e4b2..03ebdf4 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -70,6 +70,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0 ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0 + ad_mask_k_largest: NonNegativeInt = 0 is_api: bool = True @root_validator(skip_on_failure=True) @@ -190,6 +191,7 @@ _all_args = [ ("ad_confidence", "ADetailer confidence"), ("ad_mask_min_ratio", "ADetailer mask min ratio"), ("ad_mask_max_ratio", "ADetailer mask max ratio"), + ("ad_mask_k_largest", "ADetailer mask only top k largest"), ("ad_x_offset", "ADetailer x offset"), ("ad_y_offset", "ADetailer y offset"), ("ad_dilate_erode", "ADetailer dilate/erode"), diff --git a/adetailer/mask.py b/adetailer/mask.py index 9209b45..eaa90c8 100644 --- a/adetailer/mask.py +++ b/adetailer/mask.py @@ -215,6 +215,16 @@ def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutp return pred +def filter_take_largest(pred: PredictOutput, k: int) -> PredictOutput: + if not pred.bboxes or k == 0: + return pred + areas = [bbox_area(bbox) for bbox in pred.bboxes] + idx = np.argsort(areas)[-k:] + pred.bboxes = [pred.bboxes[i] for i in idx] + pred.masks = [pred.masks[i] for i in idx] + return pred + + # Merge / Invert def mask_merge(masks: list[Image.Image]) -> list[Image.Image]: arrs = [np.array(m) for m in masks] diff --git a/adetailer/ui.py b/adetailer/ui.py index b4e585e..10d6d67 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -194,7 +194,7 @@ def detection(w: Widgets, n: int, is_img2img: bool): eid = partial(elem_id, n=n, is_img2img=is_img2img) with gr.Row(): - with gr.Column(): + with gr.Column(variant="compact"): w.ad_confidence = gr.Slider( label="Detection model confidence threshold" + suffix(n), minimum=0.0, @@ -204,6 +204,15 @@ def detection(w: Widgets, n: int, is_img2img: bool): visible=True, elem_id=eid("ad_confidence"), ) + w.ad_mask_k_largest = gr.Slider( + label="Mask only the top k largest (0 to disable)" + suffix(n), + minumum=0, + maximum=5, + step=1, + value=0, + visible=True, + elem_id=eid("ad_mask_k_largest") + ) with gr.Column(variant="compact"): w.ad_mask_min_ratio = gr.Slider( @@ -226,6 +235,7 @@ def detection(w: Widgets, n: int, is_img2img: bool): ) + def mask_preprocessing(w: Widgets, n: int, is_img2img: bool): eid = partial(elem_id, n=n, is_img2img=is_img2img) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index a2747c9..7b6e2fd 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -26,7 +26,7 @@ from adetailer import ( ) from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, EnableChecker from adetailer.common import PredictOutput -from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes +from adetailer.mask import filter_take_largest, filter_by_ratio, mask_preprocess, sort_bboxes from adetailer.traceback import rich_traceback from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models @@ -463,6 +463,7 @@ class AfterDetailerScript(scripts.Script): pred = filter_by_ratio( pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio ) + pred = filter_take_largest(pred, k=args.ad_mask_k_largest) pred = self.sort_bboxes(pred) return mask_preprocess( pred.masks,