diff --git a/adetailer/common.py b/adetailer/common.py index 37c0320..1a1fd16 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections import OrderedDict from dataclasses import dataclass +from enum import IntEnum from pathlib import Path from typing import Optional, Union @@ -20,6 +21,12 @@ class PredictOutput: preview: Optional[Image.Image] = None +class SortBy(IntEnum): + NONE = 0 + POSITION = 1 + AREA = 2 + + def get_models( model_dir: Union[str, Path], huggingface: bool = True ) -> OrderedDict[str, Optional[str]]: @@ -190,3 +197,42 @@ def mask_preprocess( masks = [offset(m, x_offset, y_offset) for m in masks] return masks + + +def _key_position(bbox: list[float]) -> float: + """ + Left to right + + Parameters + ---------- + bbox: list[float] + list of [x1, y1, x2, y2] + """ + return bbox[0] + + +def _key_area(bbox: list[float]) -> float: + """ + Large to small + + Parameters + ---------- + bbox: list[float] + list of [x1, y1, x2, y2] + """ + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + return -area + + +def sort_bboxes( + pred: PredictOutput, order: int | SortBy = SortBy.NONE +) -> PredictOutput: + if order == SortBy.NONE or not pred.bboxes: + return pred + + items = len(pred.bboxes) + key = _key_area if order == SortBy.AREA else _key_position + idx = sorted(range(items), key=lambda i: key(pred.bboxes[i])) + pred.bboxes = [pred.bboxes[i] for i in idx] + pred.masks = [pred.masks[i] for i in idx] + return pred diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index e3d007a..10682f4 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -24,7 +24,7 @@ from adetailer import ( mediapipe_predict, ultralytics_predict, ) -from adetailer.common import mask_preprocess +from adetailer.common import PredictOutput, mask_preprocess, sort_bboxes from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists from sd_webui import images, safe, script_callbacks, scripts, shared @@ -378,6 +378,11 @@ class AfterDetailerScript(scripts.Script): raise ValueError(msg) return model_mapping[name] + def sort_bboxes(self, pred: PredictOutput) -> PredictOutput: + sortby = opts.data.get("ad_bbox_sortby", 2) + pred = sort_bboxes(pred, sortby) + return pred + def i2i_prompts_replace( self, i2i, prompts: list[str], negative_prompts: list[str], j: int ): @@ -425,6 +430,7 @@ class AfterDetailerScript(scripts.Script): with ChangeTorchLoad(): pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs) + pred = self.sort_bboxes(pred) masks = mask_preprocess( pred.masks, kernel=args.ad_dilate_erode, @@ -556,6 +562,24 @@ def on_ui_settings(): ), ) + bbox_sortby = ["None", "Position (left to right)", "Area (large to small)"] + bbox_sortby_args = { + "choices": bbox_sortby, + "type": "index", + "interactive": True, + } + + shared.opts.add_option( + "ad_bbox_sortby", + shared.OptionInfo( + default=0, + label="Sort bounding boxes by", + component=gr.Radio, + component_args=bbox_sortby_args, + section=section, + ), + ) + script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_after_component(on_after_component)