Feature: Filter k most confident masks (#720)

* add filter to things that return confidences need to add ui elements to select between the two methods

* add ui elements for controlling method

* forgot to remove this

* fix incorrect early exit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: pop mask only top k params

* fix: filter confidences

* refactor: change to one public function

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dowon <ks2515@naver.com>
This commit is contained in:
Collin Avidano
2024-09-25 07:12:28 -04:00
committed by GitHub
parent 03ec9d004a
commit 9ceb58685a
7 changed files with 56 additions and 11 deletions

View File

@@ -294,14 +294,21 @@ def detection(w: Widgets, n: int, is_img2img: bool):
visible=True, visible=True,
elem_id=eid("ad_confidence"), elem_id=eid("ad_confidence"),
) )
w.ad_mask_k_largest = gr.Slider( w.ad_mask_filter_method = gr.Radio(
label="Mask only the top k largest (0 to disable)" + suffix(n), choices=["Area", "Confidence"],
value="Area",
label="Method to filter top k masks by (confidence or area)",
visible=True,
elem_id=eid("ad_mask_filter_method"),
)
w.ad_mask_k = gr.Slider(
label="Mask only the top k (0 to disable)" + suffix(n),
minimum=0, minimum=0,
maximum=10, maximum=10,
step=1, step=1,
value=0, value=0,
visible=True, visible=True,
elem_id=eid("ad_mask_k_largest"), elem_id=eid("ad_mask_k"),
) )
with gr.Column(variant="compact"): with gr.Column(variant="compact"):

View File

@@ -60,7 +60,8 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_prompt: str = "" ad_prompt: str = ""
ad_negative_prompt: str = "" ad_negative_prompt: str = ""
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3 ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
ad_mask_k_largest: NonNegativeInt = 0 ad_mask_filter_method: Literal["Area", "Confidence"] = "Area"
ad_mask_k: NonNegativeInt = 0
ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0 ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0
ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0 ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0
ad_dilate_erode: int = 4 ad_dilate_erode: int = 4
@@ -131,7 +132,11 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ppop("ADetailer prompt") ppop("ADetailer prompt")
ppop("ADetailer negative prompt") ppop("ADetailer negative prompt")
p.pop("ADetailer tab enable", None) # always pop p.pop("ADetailer tab enable", None) # always pop
ppop("ADetailer mask only top k largest", cond=0) ppop(
"ADetailer mask only top k",
["ADetailer mask only top k", "ADetailer method to decide top k masks"],
cond=0,
)
ppop("ADetailer mask min ratio", cond=0.0) ppop("ADetailer mask min ratio", cond=0.0)
ppop("ADetailer mask max ratio", cond=1.0) ppop("ADetailer mask max ratio", cond=1.0)
ppop("ADetailer x offset", cond=0) ppop("ADetailer x offset", cond=0)
@@ -217,7 +222,8 @@ _all_args = [
("ad_prompt", "ADetailer prompt"), ("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"), ("ad_negative_prompt", "ADetailer negative prompt"),
("ad_confidence", "ADetailer confidence"), ("ad_confidence", "ADetailer confidence"),
("ad_mask_k_largest", "ADetailer mask only top k largest"), ("ad_mask_filter_method", "ADetailer method to decide top k masks"),
("ad_mask_k", "ADetailer mask only top k"),
("ad_mask_min_ratio", "ADetailer mask min ratio"), ("ad_mask_min_ratio", "ADetailer mask min ratio"),
("ad_mask_max_ratio", "ADetailer mask max ratio"), ("ad_mask_max_ratio", "ADetailer mask max ratio"),
("ad_x_offset", "ADetailer x offset"), ("ad_x_offset", "ADetailer x offset"),

View File

@@ -22,6 +22,7 @@ T = TypeVar("T", int, float)
class PredictOutput(Generic[T]): class PredictOutput(Generic[T]):
bboxes: list[list[T]] = field(default_factory=list) bboxes: list[list[T]] = field(default_factory=list)
masks: list[Image.Image] = field(default_factory=list) masks: list[Image.Image] = field(default_factory=list)
confidences: list[float] = field(default_factory=list)
preview: Optional[Image.Image] = None preview: Optional[Image.Image] = None

View File

@@ -225,6 +225,7 @@ def filter_by_ratio(
idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)] idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
pred.bboxes = [pred.bboxes[i] for i in idx] pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx] pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred return pred
@@ -236,9 +237,31 @@ def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
idx = idx[::-1] idx = idx[::-1]
pred.bboxes = [pred.bboxes[i] for i in idx] pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx] pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred return pred
def filter_k_most_confident(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
if not pred.bboxes or not pred.confidences or k == 0:
return pred
idx = np.argsort(pred.confidences)[-k:]
idx = idx[::-1]
pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx]
pred.confidences = [pred.confidences[i] for i in idx]
return pred
def filter_k_by(
pred: PredictOutput[T], k: int = 0, by: str = "Area"
) -> PredictOutput[T]:
if by == "Area":
return filter_k_largest(pred, k)
if by == "Confidence":
return filter_k_most_confident(pred, k)
raise RuntimeError
# Merge / Invert # Merge / Invert
def mask_merge(masks: list[Image.Image]) -> list[Image.Image]: def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
arrs = [np.array(m) for m in masks] arrs = [np.array(m) for m in masks]

View File

@@ -52,6 +52,7 @@ def mediapipe_face_detection(
preview_array = img_array.copy() preview_array = img_array.copy()
bboxes = [] bboxes = []
confidences = []
for detection in pred.detections: for detection in pred.detections:
draw_util.draw_detection(preview_array, detection) draw_util.draw_detection(preview_array, detection)
@@ -63,12 +64,15 @@ def mediapipe_face_detection(
x2 = x1 + w x2 = x1 + w
y2 = y1 + h y2 = y1 + h
confidences.append(detection.score)
bboxes.append([x1, y1, x2, y2]) bboxes.append([x1, y1, x2, y2])
masks = create_mask_from_bbox(bboxes, image.size) masks = create_mask_from_bbox(bboxes, image.size)
preview = Image.fromarray(preview_array) preview = Image.fromarray(preview_array)
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) return PredictOutput(
bboxes=bboxes, masks=masks, confidences=confidences, preview=preview
)
def mediapipe_face_mesh( def mediapipe_face_mesh(
@@ -141,7 +145,6 @@ def mediapipe_face_mesh_eyes_only(
preview = image.copy() preview = image.copy()
masks = [] masks = []
for landmarks in pred.multi_face_landmarks: for landmarks in pred.multi_face_landmarks:
points = np.array( points = np.array(
[[land.x * w, land.y * h] for land in landmarks.landmark], dtype=int [[land.x * w, land.y * h] for land in landmarks.landmark], dtype=int

View File

@@ -37,11 +37,16 @@ def ultralytics_predict(
masks = create_mask_from_bbox(bboxes, image.size) masks = create_mask_from_bbox(bboxes, image.size)
else: else:
masks = mask_to_pil(pred[0].masks.data, image.size) masks = mask_to_pil(pred[0].masks.data, image.size)
confidences = pred[0].boxes.conf.cpu().numpy().tolist()
preview = pred[0].plot() preview = pred[0].plot()
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB) preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
preview = Image.fromarray(preview) preview = Image.fromarray(preview)
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) return PredictOutput(
bboxes=bboxes, masks=masks, confidences=confidences, preview=preview
)
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str): def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):

View File

@@ -52,7 +52,7 @@ from adetailer.args import (
from adetailer.common import PredictOutput, ensure_pil_image, safe_mkdir from adetailer.common import PredictOutput, ensure_pil_image, safe_mkdir
from adetailer.mask import ( from adetailer.mask import (
filter_by_ratio, filter_by_ratio,
filter_k_largest, filter_k_by,
has_intersection, has_intersection,
is_all_black, is_all_black,
mask_preprocess, mask_preprocess,
@@ -596,7 +596,7 @@ class AfterDetailerScript(scripts.Script):
pred = filter_by_ratio( pred = filter_by_ratio(
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
) )
pred = filter_k_largest(pred, k=args.ad_mask_k_largest) pred = filter_k_by(pred, k=args.ad_mask_k, by=args.ad_mask_filter_method)
pred = self.sort_bboxes(pred) pred = self.sort_bboxes(pred)
masks = mask_preprocess( masks = mask_preprocess(
pred.masks, pred.masks,