mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-25 16:59:00 +00:00
Merge branch 'dev'
This commit is contained in:
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
@@ -8,7 +8,6 @@ on:
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'Bing-su/adetailer' || github.repository == ''
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -2,6 +2,10 @@ repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=100]
|
||||
- id: check-merge-conflict
|
||||
- id: check-case-conflict
|
||||
- id: check-ast
|
||||
- id: trailing-whitespace
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
@@ -9,7 +13,7 @@ repos:
|
||||
- id: mixed-line-ending
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.2
|
||||
rev: v0.3.1
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# Changelog
|
||||
|
||||
## 2024-03-16
|
||||
|
||||
- YOLO World v2, YOLO9 지원가능한 버전으로 ultralytics 업데이트
|
||||
- inpaint full res인 경우 인페인트 모드에서 동작하게 변경
|
||||
- inpaint full res가 아닌 경우, 사용자가 입력한 마스크와 교차점이 있는 마스크만 선택하여 사용함
|
||||
|
||||
## 2024-03-01
|
||||
|
||||
- v24.3.0
|
||||
|
||||
@@ -102,3 +102,11 @@ ADetailer works in three simple steps.
|
||||
1. Create an image.
|
||||
2. Detect object with a detection model and create a mask image.
|
||||
3. Inpaint using the image from 1 and the mask from 2.
|
||||
|
||||
## Development
|
||||
|
||||
ADetailer is developed and tested using the stable-diffusion 1.5 model, for the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) repository only.
|
||||
|
||||
## License
|
||||
|
||||
ADetailer is a derivative work that uses two AGPL-licensed works (stable-diffusion-webui, ultralytics) and is therefore distributed under the AGPL license.
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "24.3.0"
|
||||
__version__ = "24.3.1"
|
||||
|
||||
@@ -5,16 +5,28 @@ from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Literal, NamedTuple, Optional
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
conint,
|
||||
validator,
|
||||
)
|
||||
try:
|
||||
from pydantic.v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
conint,
|
||||
validator,
|
||||
)
|
||||
except ImportError:
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
conint,
|
||||
validator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,11 +3,12 @@ from __future__ import annotations
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
from rich import print
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
REPO_ID = "Bingsu/adetailer"
|
||||
_download_failed = False
|
||||
@@ -44,7 +45,7 @@ def scan_model_dir(path_: str | Path) -> list[Path]:
|
||||
|
||||
def get_models(
|
||||
model_dir: str | Path, extra_dir: str | Path = "", huggingface: bool = True
|
||||
) -> OrderedDict[str, str | None]:
|
||||
) -> OrderedDict[str, str]:
|
||||
model_paths = [*scan_model_dir(model_dir), *scan_model_dir(extra_dir)]
|
||||
|
||||
models = OrderedDict()
|
||||
@@ -56,17 +57,17 @@ def get_models(
|
||||
"hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
|
||||
"person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
|
||||
"person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
|
||||
"yolov8x-world.pt": hf_download(
|
||||
"yolov8x-world.pt", repo_id="Bingsu/yolo-world-mirror"
|
||||
"yolov8x-worldv2.pt": hf_download(
|
||||
"yolov8x-worldv2.pt", repo_id="Bingsu/yolo-world-mirror"
|
||||
),
|
||||
}
|
||||
)
|
||||
models.update(
|
||||
{
|
||||
"mediapipe_face_full": None,
|
||||
"mediapipe_face_short": None,
|
||||
"mediapipe_face_mesh": None,
|
||||
"mediapipe_face_mesh_eyes_only": None,
|
||||
"mediapipe_face_full": "mediapipe_face_full",
|
||||
"mediapipe_face_short": "mediapipe_face_short",
|
||||
"mediapipe_face_mesh": "mediapipe_face_mesh",
|
||||
"mediapipe_face_mesh_eyes_only": "mediapipe_face_mesh_eyes_only",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -133,3 +134,11 @@ def create_bbox_from_mask(
|
||||
if bbox is not None:
|
||||
bboxes.append(list(bbox))
|
||||
return bboxes
|
||||
|
||||
|
||||
def ensure_pil_image(image: Any, mode: str = "RGB") -> Image.Image:
|
||||
if not isinstance(image, Image.Image):
|
||||
image = to_pil_image(image)
|
||||
if image.mode != mode:
|
||||
image = image.convert(mode)
|
||||
return image
|
||||
|
||||
@@ -3,13 +3,14 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from functools import partial, reduce
|
||||
from math import dist
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from adetailer.args import MASK_MERGE_INVERT
|
||||
from adetailer.common import PredictOutput
|
||||
from adetailer.common import PredictOutput, ensure_pil_image
|
||||
|
||||
|
||||
class SortBy(IntEnum):
|
||||
@@ -83,12 +84,19 @@ def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
|
||||
return ImageChops.offset(img, x, -y)
|
||||
|
||||
|
||||
def is_all_black(img: Image.Image) -> bool:
|
||||
arr = np.array(img)
|
||||
return cv2.countNonZero(arr) == 0
|
||||
def is_all_black(img: Image.Image | np.ndarray) -> bool:
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
return cv2.countNonZero(img) == 0
|
||||
|
||||
|
||||
def bbox_area(bbox: list[float]):
|
||||
def has_intersection(im1: Any, im2: Any) -> bool:
|
||||
arr1 = np.array(ensure_pil_image(im1, "L"))
|
||||
arr2 = np.array(ensure_pil_image(im2, "L"))
|
||||
return not is_all_black(cv2.bitwise_and(arr1, arr2))
|
||||
|
||||
|
||||
def bbox_area(bbox: list[float]) -> float:
|
||||
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,12 @@ except ImportError:
|
||||
get_cn_models,
|
||||
)
|
||||
|
||||
from .restore import CNHijackRestore, cn_allow_script_control
|
||||
|
||||
__all__ = [
|
||||
"ControlNetExt",
|
||||
"CNHijackRestore",
|
||||
"cn_allow_script_control",
|
||||
"controlnet_exists",
|
||||
"controlnet_type",
|
||||
"get_cn_models",
|
||||
|
||||
@@ -44,7 +44,7 @@ def run_pip(*args):
|
||||
def install():
|
||||
deps = [
|
||||
# requirements
|
||||
("ultralytics", "8.1.18", None),
|
||||
("ultralytics", "8.1.29", None),
|
||||
("mediapipe", "0.10.9", None),
|
||||
("rich", "13.0.0", None),
|
||||
# mediapipe
|
||||
|
||||
@@ -2,17 +2,41 @@
|
||||
name = "adetailer"
|
||||
description = "An object detection and auto-mask extension for stable diffusion webui."
|
||||
authors = [{ name = "dowon", email = "ks2515@naver.com" }]
|
||||
requires-python = ">=3.8,<3.12"
|
||||
requires-python = ">=3.8,<3.13"
|
||||
readme = "README.md"
|
||||
license = { text = "AGPL-3.0" }
|
||||
dependencies = [
|
||||
"ultralytics>=8.1",
|
||||
"mediapipe>=10",
|
||||
"pydantic<3",
|
||||
"rich>=13",
|
||||
"huggingface_hub",
|
||||
]
|
||||
keywords = [
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-webui",
|
||||
"adetailer",
|
||||
"ultralytics",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
[project.urls]
|
||||
repository = "https://github.com/Bing-su/adetailer"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "adetailer/__version__.py"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
known_first_party = ["launch", "modules"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"A",
|
||||
|
||||
@@ -14,9 +14,8 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageChops
|
||||
from rich import print
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
import modules
|
||||
from adetailer import (
|
||||
@@ -27,25 +26,25 @@ from adetailer import (
|
||||
ultralytics_predict,
|
||||
)
|
||||
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
|
||||
from adetailer.common import PredictOutput
|
||||
from adetailer.common import PredictOutput, ensure_pil_image
|
||||
from adetailer.mask import (
|
||||
filter_by_ratio,
|
||||
filter_k_largest,
|
||||
has_intersection,
|
||||
is_all_black,
|
||||
mask_preprocess,
|
||||
sort_bboxes,
|
||||
)
|
||||
from adetailer.traceback import rich_traceback
|
||||
from adetailer.ui import WebuiInfo, adui, ordinal, suffix
|
||||
from controlnet_ext import (
|
||||
CNHijackRestore,
|
||||
ControlNetExt,
|
||||
cn_allow_script_control,
|
||||
controlnet_exists,
|
||||
controlnet_type,
|
||||
get_cn_models,
|
||||
)
|
||||
from controlnet_ext.restore import (
|
||||
CNHijackRestore,
|
||||
cn_allow_script_control,
|
||||
)
|
||||
from modules import images, paths, safe, script_callbacks, scripts, shared
|
||||
from modules.devices import NansException
|
||||
from modules.processing import (
|
||||
@@ -565,27 +564,24 @@ class AfterDetailerScript(scripts.Script):
|
||||
sortby_idx = BBOX_SORTBY.index(sortby)
|
||||
return sort_bboxes(pred, sortby_idx)
|
||||
|
||||
def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
|
||||
def pred_preprocessing(self, p, pred: PredictOutput, args: ADetailerArgs):
|
||||
pred = filter_by_ratio(
|
||||
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
|
||||
)
|
||||
pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
|
||||
pred = self.sort_bboxes(pred)
|
||||
return mask_preprocess(
|
||||
masks = mask_preprocess(
|
||||
pred.masks,
|
||||
kernel=args.ad_dilate_erode,
|
||||
x_offset=args.ad_x_offset,
|
||||
y_offset=args.ad_y_offset,
|
||||
merge_invert=args.ad_mask_merge_invert,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def ensure_rgb_image(image: Any):
|
||||
if not isinstance(image, Image.Image):
|
||||
image = to_pil_image(image)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
if self.is_img2img_inpaint(p) and not self.is_inpaint_only_masked(p):
|
||||
invert = p.inpainting_mask_invert
|
||||
image_mask = ensure_pil_image(p.image_mask, mode="L")
|
||||
masks = self.inpaint_mask_filter(image_mask, masks, invert)
|
||||
return masks
|
||||
|
||||
@staticmethod
|
||||
def i2i_prompts_replace(
|
||||
@@ -637,16 +633,30 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
@staticmethod
|
||||
def is_img2img_inpaint(p) -> bool:
|
||||
return hasattr(p, "image_mask") and bool(p.image_mask)
|
||||
return hasattr(p, "image_mask") and p.image_mask is not None
|
||||
|
||||
@staticmethod
|
||||
def is_inpaint_only_masked(p) -> bool:
|
||||
return hasattr(p, "inpaint_full_res") and p.inpaint_full_res
|
||||
|
||||
@staticmethod
|
||||
def inpaint_mask_filter(
|
||||
img2img_mask: Image.Image, ad_mask: list[Image.Image], invert: int = 0
|
||||
) -> list[Image.Image]:
|
||||
if invert:
|
||||
img2img_mask = ImageChops.invert(img2img_mask)
|
||||
return [mask for mask in ad_mask if has_intersection(img2img_mask, mask)]
|
||||
|
||||
@rich_traceback
|
||||
def process(self, p, *args_):
|
||||
if getattr(p, "_ad_disabled", False):
|
||||
return
|
||||
|
||||
if self.is_img2img_inpaint(p):
|
||||
if self.is_img2img_inpaint(p) and is_all_black(p.image_mask):
|
||||
p._ad_disabled = True
|
||||
msg = "[-] ADetailer: img2img inpainting detected. adetailer disabled."
|
||||
msg = (
|
||||
"[-] ADetailer: img2img inpainting with no mask -- adetailer disabled."
|
||||
)
|
||||
print(msg)
|
||||
return
|
||||
|
||||
@@ -700,7 +710,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
with change_torch_load():
|
||||
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
||||
|
||||
masks = self.pred_preprocessing(pred, args)
|
||||
masks = self.pred_preprocessing(p, pred, args)
|
||||
shared.state.assign_current_image(pred.preview)
|
||||
|
||||
if not masks:
|
||||
@@ -726,7 +736,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
p2 = copy(i2i)
|
||||
for j in range(steps):
|
||||
p2.image_mask = masks[j]
|
||||
p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
|
||||
p2.init_images[0] = ensure_pil_image(p2.init_images[0], "RGB")
|
||||
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
||||
|
||||
if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
||||
@@ -760,7 +770,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
return
|
||||
|
||||
pp.image = self.get_i2i_init_image(p, pp)
|
||||
pp.image = self.ensure_rgb_image(pp.image)
|
||||
pp.image = ensure_pil_image(pp.image, "RGB")
|
||||
init_image = copy(pp.image)
|
||||
arg_list = self.get_args(p, *args_)
|
||||
params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
|
||||
|
||||
69
tests/test_common.py
Normal file
69
tests/test_common.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from adetailer.common import create_bbox_from_mask, create_mask_from_bbox
|
||||
|
||||
|
||||
def test_create_mask_from_bbox():
|
||||
img = Image.new("L", (10, 10), color="black")
|
||||
bbox = [[1.0, 1.0, 2.0, 2.0], [7.0, 7.0, 8.0, 8.0]]
|
||||
masks = create_mask_from_bbox(bbox, img.size)
|
||||
expect1 = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 255, 255, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 255, 255, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
expect2 = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 255, 255, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 255, 255, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
assert len(masks) == len(bbox)
|
||||
arr1 = np.array(masks[0])
|
||||
arr2 = np.array(masks[1])
|
||||
assert arr1.shape == expect1.shape
|
||||
assert arr2.shape == expect2.shape
|
||||
assert arr1.shape == (10, 10)
|
||||
assert arr1.dtype == expect1.dtype
|
||||
assert arr2.dtype == expect2.dtype
|
||||
assert np.array_equal(arr1, expect1)
|
||||
assert np.array_equal(arr2, expect2)
|
||||
|
||||
# The function correctly receives a list of masks and the shape of the image.
|
||||
|
||||
|
||||
def test_create_bbox_from_mask():
|
||||
mask = Image.new("L", (10, 10), color="black")
|
||||
draw = ImageDraw.Draw(mask)
|
||||
draw.rectangle((2, 2, 5, 5), fill="white")
|
||||
|
||||
result = create_bbox_from_mask([mask], (10, 10))
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert all(isinstance(bbox, list) for bbox in result)
|
||||
assert all(len(bbox) == 4 for bbox in result)
|
||||
assert result[0] == [2, 2, 6, 6]
|
||||
|
||||
result = create_bbox_from_mask([mask], (256, 256))
|
||||
assert result[0] == [38, 38, 166, 166]
|
||||
154
tests/test_mask.py
Normal file
154
tests/test_mask.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from adetailer.mask import dilate_erode, has_intersection, is_all_black, offset
|
||||
|
||||
|
||||
def test_dilate_positive_value():
|
||||
img = Image.new("L", (10, 10), color="black")
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle((3, 3, 5, 5), fill="white")
|
||||
value = 3
|
||||
|
||||
result = dilate_erode(img, value)
|
||||
|
||||
assert isinstance(result, Image.Image)
|
||||
assert result.size == (10, 10)
|
||||
|
||||
expect = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 255, 255, 255, 255, 255, 0, 0, 0],
|
||||
[0, 0, 255, 255, 255, 255, 255, 0, 0, 0],
|
||||
[0, 0, 255, 255, 255, 255, 255, 0, 0, 0],
|
||||
[0, 0, 255, 255, 255, 255, 255, 0, 0, 0],
|
||||
[0, 0, 255, 255, 255, 255, 255, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
assert np.array_equal(np.array(result), expect)
|
||||
|
||||
|
||||
def test_offset():
|
||||
img = Image.new("L", (10, 10), color="black")
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle((4, 4, 5, 5), fill="white")
|
||||
|
||||
result = offset(img, x=1, y=2)
|
||||
|
||||
assert isinstance(result, Image.Image)
|
||||
assert result.size == (10, 10)
|
||||
|
||||
expect = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 255, 255, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 255, 255, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
assert np.array_equal(np.array(result), expect)
|
||||
|
||||
|
||||
def test_is_all_black_1():
|
||||
img = Image.new("L", (10, 10), color="black")
|
||||
assert is_all_black(img)
|
||||
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle((4, 4, 5, 5), fill="white")
|
||||
assert not is_all_black(img)
|
||||
|
||||
|
||||
def test_is_all_black_2():
|
||||
img = np.zeros((10, 10), dtype=np.uint8)
|
||||
assert is_all_black(img)
|
||||
|
||||
img[4:6, 4:6] = 255
|
||||
assert not is_all_black(img)
|
||||
|
||||
|
||||
def test_has_intersection_1():
|
||||
arr1 = np.array(
|
||||
[
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
]
|
||||
)
|
||||
arr2 = arr1.copy()
|
||||
assert not has_intersection(arr1, arr2)
|
||||
|
||||
|
||||
def test_has_intersection_2():
|
||||
arr1 = np.array(
|
||||
[
|
||||
[0, 0, 0, 0],
|
||||
[0, 255, 255, 0],
|
||||
[0, 255, 255, 0],
|
||||
[0, 0, 0, 0],
|
||||
]
|
||||
)
|
||||
arr2 = np.array(
|
||||
[
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 255, 255],
|
||||
[0, 0, 255, 255],
|
||||
]
|
||||
)
|
||||
assert has_intersection(arr1, arr2)
|
||||
|
||||
arr3 = np.array(
|
||||
[
|
||||
[255, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 255],
|
||||
[0, 0, 255, 255],
|
||||
]
|
||||
)
|
||||
assert not has_intersection(arr1, arr3)
|
||||
|
||||
|
||||
def test_has_intersection_3():
|
||||
img1 = Image.new("L", (10, 10), color="black")
|
||||
draw1 = ImageDraw.Draw(img1)
|
||||
draw1.rectangle((3, 3, 5, 5), fill="white")
|
||||
img2 = Image.new("L", (10, 10), color="black")
|
||||
draw2 = ImageDraw.Draw(img2)
|
||||
draw2.rectangle((6, 6, 8, 8), fill="white")
|
||||
assert not has_intersection(img1, img2)
|
||||
|
||||
img3 = Image.new("L", (10, 10), color="black")
|
||||
draw3 = ImageDraw.Draw(img3)
|
||||
draw3.rectangle((2, 2, 8, 8), fill="white")
|
||||
assert has_intersection(img1, img3)
|
||||
|
||||
|
||||
def test_has_intersection_4():
|
||||
img1 = Image.new("RGB", (10, 10), color="black")
|
||||
draw1 = ImageDraw.Draw(img1)
|
||||
draw1.rectangle((3, 3, 5, 5), fill="white")
|
||||
img2 = Image.new("RGBA", (10, 10), color="black")
|
||||
draw2 = ImageDraw.Draw(img2)
|
||||
draw2.rectangle((2, 2, 8, 8), fill="white")
|
||||
assert has_intersection(img1, img2)
|
||||
|
||||
|
||||
def test_has_intersection_5():
|
||||
img1 = Image.new("RGB", (10, 10), color="black")
|
||||
draw1 = ImageDraw.Draw(img1)
|
||||
draw1.rectangle((4, 4, 5, 5), fill="white")
|
||||
img2 = np.full((10, 10, 4), 255, dtype=np.uint8)
|
||||
assert has_intersection(img1, img2)
|
||||
@@ -26,7 +26,7 @@ def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str):
|
||||
|
||||
|
||||
def test_yolo_world_default(sample_image: Image.Image):
|
||||
model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt")
|
||||
model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-worldv2.pt")
|
||||
result = ultralytics_predict(model_path, sample_image)
|
||||
assert result.preview is not None
|
||||
|
||||
@@ -43,6 +43,6 @@ def test_yolo_world_default(sample_image: Image.Image):
|
||||
],
|
||||
)
|
||||
def test_yolo_world(sample_image2: Image.Image, klass: str):
|
||||
model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt")
|
||||
model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-worldv2.pt")
|
||||
result = ultralytics_predict(model_path, sample_image2, classes=klass)
|
||||
assert result.preview is not None
|
||||
|
||||
Reference in New Issue
Block a user