feat: Masking only the top k largest predictions (#264)

This commit is contained in:
JudahGazit
2023-08-08 14:15:37 +03:00
committed by GitHub
parent 3e80b2d824
commit a1b498ff39
4 changed files with 25 additions and 2 deletions

View File

@@ -70,6 +70,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
ad_mask_k_largest: NonNegativeInt = 0
is_api: bool = True
@root_validator(skip_on_failure=True)
@@ -190,6 +191,7 @@ _all_args = [
("ad_confidence", "ADetailer confidence"),
("ad_mask_min_ratio", "ADetailer mask min ratio"),
("ad_mask_max_ratio", "ADetailer mask max ratio"),
("ad_mask_k_largest", "ADetailer mask only top k largest"),
("ad_x_offset", "ADetailer x offset"),
("ad_y_offset", "ADetailer y offset"),
("ad_dilate_erode", "ADetailer dilate/erode"),

View File

@@ -215,6 +215,16 @@ def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutp
return pred
def filter_take_largest(pred: PredictOutput, k: int) -> PredictOutput:
    """Keep only the ``k`` detections with the largest bounding-box area.

    Parameters
    ----------
    pred : PredictOutput
        Detection result. Its ``bboxes`` and ``masks`` are selected with the
        same indices and are replaced on the object in place.
    k : int
        Number of detections to keep. ``0`` (or any non-positive value)
        disables the filter. A ``k`` larger than the number of detections
        keeps all of them.

    Returns
    -------
    PredictOutput
        The same ``pred`` object. When filtering is applied, the kept
        entries are ordered by ascending area (largest last).
    """
    # A negative k would turn the ``[-k:]`` slice below into "drop the |k|
    # smallest boxes", so every non-positive k is treated as "disabled".
    if not pred.bboxes or k <= 0:
        return pred
    areas = [bbox_area(bbox) for bbox in pred.bboxes]
    # Stable sort: boxes with equal area keep their detection order, so the
    # selection is deterministic when there are ties.
    idx = np.argsort(areas, kind="stable")[-k:]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred
# Merge / Invert
def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
arrs = [np.array(m) for m in masks]

View File

@@ -194,7 +194,7 @@ def detection(w: Widgets, n: int, is_img2img: bool):
eid = partial(elem_id, n=n, is_img2img=is_img2img)
with gr.Row():
with gr.Column():
with gr.Column(variant="compact"):
w.ad_confidence = gr.Slider(
label="Detection model confidence threshold" + suffix(n),
minimum=0.0,
@@ -204,6 +204,15 @@ def detection(w: Widgets, n: int, is_img2img: bool):
visible=True,
elem_id=eid("ad_confidence"),
)
# Slider for how many of the largest detections get masked; 0 disables the filter.
w.ad_mask_k_largest = gr.Slider(
    label="Mask only the top k largest (0 to disable)" + suffix(n),
    # Was misspelled ``minumum``, so the lower bound was never passed under
    # the keyword the slider actually reads.
    minimum=0,
    maximum=5,
    step=1,
    value=0,
    visible=True,
    elem_id=eid("ad_mask_k_largest"),
)
with gr.Column(variant="compact"):
w.ad_mask_min_ratio = gr.Slider(
@@ -226,6 +235,7 @@ def detection(w: Widgets, n: int, is_img2img: bool):
)
def mask_preprocessing(w: Widgets, n: int, is_img2img: bool):
eid = partial(elem_id, n=n, is_img2img=is_img2img)

View File

@@ -26,7 +26,7 @@ from adetailer import (
)
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, EnableChecker
from adetailer.common import PredictOutput
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.mask import filter_take_largest, filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.traceback import rich_traceback
from adetailer.ui import adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
@@ -463,6 +463,7 @@ class AfterDetailerScript(scripts.Script):
pred = filter_by_ratio(
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
)
pred = filter_take_largest(pred, k=args.ad_mask_k_largest)
pred = self.sort_bboxes(pred)
return mask_preprocess(
pred.masks,