fix: predict output type, filter by ratio

This commit is contained in:
Bingsu
2023-05-18 21:09:54 +09:00
parent 27db3b7e53
commit 4e55858590
2 changed files with 30 additions and 9 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
@@ -13,8 +13,8 @@ repo_id = "Bingsu/adetailer"
@dataclass
class PredictOutput:
bboxes: Optional[list[list[int]]] = None
masks: Optional[list[Image.Image]] = None
bboxes: list[list[float]] = field(default_factory=list)
masks: list[Image.Image] = field(default_factory=list)
preview: Optional[Image.Image] = None

View File

@@ -84,7 +84,7 @@ def bbox_area(bbox: list[float]):
def mask_preprocess(
masks: list[Image.Image] | None,
masks: list[Image.Image],
kernel: int = 0,
x_offset: int = 0,
y_offset: int = 0,
@@ -95,7 +95,7 @@ def mask_preprocess(
Parameters
----------
masks: list[Image.Image] | None
masks: list[Image.Image]
A list of masks
kernel: int
kernel size of dilation or erosion
@@ -109,14 +109,16 @@ def mask_preprocess(
list[Image.Image]
A list of processed masks
"""
if masks is None:
if not masks:
return []
masks = [dilate_erode(m, kernel) for m in masks]
masks = [m for m in masks if not is_all_black(m)]
if x_offset != 0 or y_offset != 0:
masks = [offset(m, x_offset, y_offset) for m in masks]
if kernel != 0:
masks = [dilate_erode(m, kernel) for m in masks]
masks = [m for m in masks if not is_all_black(m)]
return masks
@@ -163,7 +165,7 @@ def _key_area(bbox: list[float]) -> float:
def sort_bboxes(
pred: PredictOutput, order: int | SortBy = SortBy.NONE
) -> PredictOutput:
if order == SortBy.NONE or not pred.bboxes:
if order == SortBy.NONE or len(pred.bboxes) <= 1:
return pred
if order == SortBy.LEFT_TO_RIGHT:
@@ -182,3 +184,22 @@ def sort_bboxes(
pred.bboxes = [pred.bboxes[i] for i in idx]
pred.masks = [pred.masks[i] for i in idx]
return pred
# Filter by ratio
def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool:
    """
    Check whether a bbox's share of a reference area lies within a range.

    Parameters
    ----------
        bbox: list[float]
            Bounding box passed to `bbox_area`
        low: float
            Lower bound of the accepted ratio (inclusive)
        high: float
            Upper bound of the accepted ratio (inclusive)
        orig_area: int
            Reference area the bbox area is divided by

    Returns
    -------
        bool
            True if `low <= bbox_area(bbox) / orig_area <= high`
    """
    ratio = bbox_area(bbox) / orig_area
    return low <= ratio <= high
def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
    """
    Drop detections whose bbox area, relative to the preview image area,
    falls outside `[low, high]`.

    Parameters
    ----------
        pred: PredictOutput
            Prediction whose `bboxes` and `masks` are filtered in place
        low: float
            Minimum accepted area ratio (inclusive)
        high: float
            Maximum accepted area ratio (inclusive)

    Returns
    -------
        PredictOutput
            The same `pred` object. It is returned untouched when it has no
            bboxes or no preview image.
    """
    # `preview` is optional; without it there is no reference area to
    # compare against, so leave the prediction as it is instead of raising
    # AttributeError on `None.size`.
    if not pred.bboxes or pred.preview is None:
        return pred

    w, h = pred.preview.size
    orig_area = w * h
    # Collect the surviving indices once so bboxes and masks stay aligned.
    idx = [
        i
        for i, bbox in enumerate(pred.bboxes)
        if is_in_ratio(bbox, low, high, orig_area)
    ]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred