From 4e5585859005dc0a4233e2de6b734c2ed18a2b44 Mon Sep 17 00:00:00 2001 From: Bingsu Date: Thu, 18 May 2023 21:09:54 +0900 Subject: [PATCH] fix: predict output type, filter by ratio --- adetailer/common.py | 6 +++--- adetailer/mask.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/adetailer/common.py b/adetailer/common.py index bd4f10b..d52e568 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import OrderedDict -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union @@ -13,8 +13,8 @@ repo_id = "Bingsu/adetailer" @dataclass class PredictOutput: - bboxes: Optional[list[list[int]]] = None - masks: Optional[list[Image.Image]] = None + bboxes: list[list[float]] = field(default_factory=list) + masks: list[Image.Image] = field(default_factory=list) preview: Optional[Image.Image] = None diff --git a/adetailer/mask.py b/adetailer/mask.py index 53faaab..3860a65 100644 --- a/adetailer/mask.py +++ b/adetailer/mask.py @@ -84,7 +84,7 @@ def bbox_area(bbox: list[float]): def mask_preprocess( - masks: list[Image.Image] | None, + masks: list[Image.Image], kernel: int = 0, x_offset: int = 0, y_offset: int = 0, @@ -95,7 +95,7 @@ def mask_preprocess( Parameters ---------- - masks: list[Image.Image] | None + masks: list[Image.Image] A list of masks kernel: int kernel size of dilation or erosion @@ -109,14 +109,16 @@ def mask_preprocess( list[Image.Image] A list of processed masks """ - if masks is None: + if not masks: return [] - masks = [dilate_erode(m, kernel) for m in masks] - masks = [m for m in masks if not is_all_black(m)] if x_offset != 0 or y_offset != 0: masks = [offset(m, x_offset, y_offset) for m in masks] + if kernel != 0: + masks = [dilate_erode(m, kernel) for m in masks] + masks = [m for m in masks if not is_all_black(m)] + return masks @@ -163,7 +165,7 @@ def _key_area(bbox: list[float]) -> float: def sort_bboxes( pred: PredictOutput, order: int | SortBy = SortBy.NONE ) -> PredictOutput: - if order == SortBy.NONE or not pred.bboxes: + if order == SortBy.NONE or len(pred.bboxes) <= 1: return pred if order == SortBy.LEFT_TO_RIGHT: @@ -182,3 +184,22 @@ def sort_bboxes( pred.bboxes = [pred.bboxes[i] for i in idx] pred.masks = [pred.masks[i] for i in idx] return pred + + +# Filter by ratio +def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool: + area = bbox_area(bbox) + return low <= area / orig_area <= high + + +def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput: + if not pred.bboxes: + return pred + + w, h = pred.preview.size + orig_area = w * h + items = len(pred.bboxes) + idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)] + pred.bboxes = [pred.bboxes[i] for i in idx] + pred.masks = [pred.masks[i] for i in idx] + return pred