mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-02-21 07:34:05 +00:00
test: mask_to_pil test
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from adetailer.ultralytics import ultralytics_predict
|
||||
from adetailer.ultralytics import mask_to_pil, ultralytics_predict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -60,3 +62,35 @@ def test_yolo_world(sample_image2: Image.Image, klass: str):
|
||||
assert len(result.masks) > 0
|
||||
assert len(result.confidences) > 0
|
||||
assert len(result.bboxes) == len(result.masks) == len(result.confidences)
|
||||
|
||||
|
||||
class TestMaskToPil:
|
||||
def test_mask_to_pil_float32(self):
|
||||
mask = torch.tensor([[[0.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
|
||||
imgs = mask_to_pil(mask, shape=(2, 2))
|
||||
|
||||
assert len(imgs) == 1
|
||||
img = imgs[0]
|
||||
assert isinstance(img, Image.Image)
|
||||
|
||||
arr = np.array(img)
|
||||
assert arr.shape == (2, 2)
|
||||
assert arr.dtype == np.uint8
|
||||
|
||||
expected = np.array([[0, 255], [0, 255]], dtype=np.uint8)
|
||||
np.testing.assert_array_equal(arr, expected)
|
||||
|
||||
def test_mask_to_pil_uint8(self):
|
||||
mask = torch.tensor([[[0, 1], [0, 1]]], dtype=torch.uint8)
|
||||
imgs = mask_to_pil(mask, shape=(2, 2))
|
||||
|
||||
assert len(imgs) == 1
|
||||
img = imgs[0]
|
||||
assert isinstance(img, Image.Image)
|
||||
|
||||
arr = np.array(img)
|
||||
assert arr.shape == (2, 2)
|
||||
assert arr.dtype == np.uint8
|
||||
|
||||
expected = np.array([[0, 255], [0, 255]], dtype=np.uint8)
|
||||
np.testing.assert_array_equal(arr, expected)
|
||||
|
||||
Reference in New Issue
Block a user