mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 19:29:54 +00:00
feat: add sorting bboxes option
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -20,6 +21,12 @@ class PredictOutput:
|
||||
preview: Optional[Image.Image] = None
|
||||
|
||||
|
||||
class SortBy(IntEnum):
|
||||
NONE = 0
|
||||
POSITION = 1
|
||||
AREA = 2
|
||||
|
||||
|
||||
def get_models(
|
||||
model_dir: Union[str, Path], huggingface: bool = True
|
||||
) -> OrderedDict[str, Optional[str]]:
|
||||
@@ -190,3 +197,42 @@ def mask_preprocess(
|
||||
masks = [offset(m, x_offset, y_offset) for m in masks]
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
def _key_position(bbox: list[float]) -> float:
|
||||
"""
|
||||
Left to right
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
"""
|
||||
return bbox[0]
|
||||
|
||||
|
||||
def _key_area(bbox: list[float]) -> float:
|
||||
"""
|
||||
Large to small
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
"""
|
||||
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
return -area
|
||||
|
||||
|
||||
def sort_bboxes(
|
||||
pred: PredictOutput, order: int | SortBy = SortBy.NONE
|
||||
) -> PredictOutput:
|
||||
if order == SortBy.NONE or not pred.bboxes:
|
||||
return pred
|
||||
|
||||
items = len(pred.bboxes)
|
||||
key = _key_area if order == SortBy.AREA else _key_position
|
||||
idx = sorted(range(items), key=lambda i: key(pred.bboxes[i]))
|
||||
pred.bboxes = [pred.bboxes[i] for i in idx]
|
||||
pred.masks = [pred.masks[i] for i in idx]
|
||||
return pred
|
||||
|
||||
@@ -24,7 +24,7 @@ from adetailer import (
|
||||
mediapipe_predict,
|
||||
ultralytics_predict,
|
||||
)
|
||||
from adetailer.common import mask_preprocess
|
||||
from adetailer.common import PredictOutput, mask_preprocess, sort_bboxes
|
||||
from adetailer.ui import adui, ordinal, suffix
|
||||
from controlnet_ext import ControlNetExt, controlnet_exists
|
||||
from sd_webui import images, safe, script_callbacks, scripts, shared
|
||||
@@ -378,6 +378,11 @@ class AfterDetailerScript(scripts.Script):
|
||||
raise ValueError(msg)
|
||||
return model_mapping[name]
|
||||
|
||||
def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
|
||||
sortby = opts.data.get("ad_bbox_sortby", 2)
|
||||
pred = sort_bboxes(pred, sortby)
|
||||
return pred
|
||||
|
||||
def i2i_prompts_replace(
|
||||
self, i2i, prompts: list[str], negative_prompts: list[str], j: int
|
||||
):
|
||||
@@ -425,6 +430,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
with ChangeTorchLoad():
|
||||
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
|
||||
|
||||
pred = self.sort_bboxes(pred)
|
||||
masks = mask_preprocess(
|
||||
pred.masks,
|
||||
kernel=args.ad_dilate_erode,
|
||||
@@ -556,6 +562,24 @@ def on_ui_settings():
|
||||
),
|
||||
)
|
||||
|
||||
bbox_sortby = ["None", "Position (left to right)", "Area (large to small)"]
|
||||
bbox_sortby_args = {
|
||||
"choices": bbox_sortby,
|
||||
"type": "index",
|
||||
"interactive": True,
|
||||
}
|
||||
|
||||
shared.opts.add_option(
|
||||
"ad_bbox_sortby",
|
||||
shared.OptionInfo(
|
||||
default=0,
|
||||
label="Sort bounding boxes by",
|
||||
component=gr.Radio,
|
||||
component_args=bbox_sortby_args,
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_after_component(on_after_component)
|
||||
|
||||
Reference in New Issue
Block a user