mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-02-09 18:00:05 +00:00
fix: predict output type, filter by ratio
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -13,8 +13,8 @@ repo_id = "Bingsu/adetailer"
|
||||
|
||||
@dataclass
|
||||
class PredictOutput:
|
||||
bboxes: Optional[list[list[int]]] = None
|
||||
masks: Optional[list[Image.Image]] = None
|
||||
bboxes: list[list[float]] = field(default_factory=list)
|
||||
masks: list[Image.Image] = field(default_factory=list)
|
||||
preview: Optional[Image.Image] = None
|
||||
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def bbox_area(bbox: list[float]):
|
||||
|
||||
|
||||
def mask_preprocess(
|
||||
masks: list[Image.Image] | None,
|
||||
masks: list[Image.Image],
|
||||
kernel: int = 0,
|
||||
x_offset: int = 0,
|
||||
y_offset: int = 0,
|
||||
@@ -95,7 +95,7 @@ def mask_preprocess(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
masks: list[Image.Image] | None
|
||||
masks: list[Image.Image]
|
||||
A list of masks
|
||||
kernel: int
|
||||
kernel size of dilation or erosion
|
||||
@@ -109,14 +109,16 @@ def mask_preprocess(
|
||||
list[Image.Image]
|
||||
A list of processed masks
|
||||
"""
|
||||
if masks is None:
|
||||
if not masks:
|
||||
return []
|
||||
|
||||
masks = [dilate_erode(m, kernel) for m in masks]
|
||||
masks = [m for m in masks if not is_all_black(m)]
|
||||
if x_offset != 0 or y_offset != 0:
|
||||
masks = [offset(m, x_offset, y_offset) for m in masks]
|
||||
|
||||
if kernel != 0:
|
||||
masks = [dilate_erode(m, kernel) for m in masks]
|
||||
masks = [m for m in masks if not is_all_black(m)]
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
@@ -163,7 +165,7 @@ def _key_area(bbox: list[float]) -> float:
|
||||
def sort_bboxes(
|
||||
pred: PredictOutput, order: int | SortBy = SortBy.NONE
|
||||
) -> PredictOutput:
|
||||
if order == SortBy.NONE or not pred.bboxes:
|
||||
if order == SortBy.NONE or len(pred.bboxes) <= 1:
|
||||
return pred
|
||||
|
||||
if order == SortBy.LEFT_TO_RIGHT:
|
||||
@@ -182,3 +184,22 @@ def sort_bboxes(
|
||||
pred.bboxes = [pred.bboxes[i] for i in idx]
|
||||
pred.masks = [pred.masks[i] for i in idx]
|
||||
return pred
|
||||
|
||||
|
||||
# Filter by ratio
|
||||
def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool:
|
||||
area = bbox_area(bbox)
|
||||
return low <= area / orig_area <= high
|
||||
|
||||
|
||||
def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
|
||||
if not pred.bboxes:
|
||||
return pred
|
||||
|
||||
w, h = pred.preview.size
|
||||
orig_area = w * h
|
||||
items = len(pred.bboxes)
|
||||
idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
|
||||
pred.bboxes = [pred.bboxes[i] for i in idx]
|
||||
pred.masks = [pred.masks[i] for i in idx]
|
||||
return pred
|
||||
|
||||
Reference in New Issue
Block a user