feat: add yolo world model

This commit is contained in:
Dowon
2024-03-01 10:54:19 +09:00
parent 692b052cfb
commit 9e9dcd5bca
7 changed files with 38 additions and 6 deletions

View File

@@ -1 +1 @@
__version__ = "24.1.2"
__version__ = "24.3.0.dev0"

View File

@@ -42,6 +42,7 @@ class ArgsList(UserList):
class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_model: str = "None"
ad_model_classes: str = ""
ad_prompt: str = ""
ad_negative_prompt: str = ""
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
@@ -111,6 +112,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
ppop = partial(self.ppop, p)
ppop("ADetailer model classes")
ppop("ADetailer prompt")
ppop("ADetailer negative prompt")
ppop("ADetailer mask only top k largest", cond=0)
@@ -183,6 +185,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
_all_args = [
("ad_model", "ADetailer model"),
("ad_model_classes", "ADetailer model classes"),
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_confidence", "ADetailer confidence"),

View File

@@ -60,6 +60,7 @@ def get_models(
)
models.update(
{
"yolov8x-world.pt": "yolov8x-world.pt",
"mediapipe_face_full": None,
"mediapipe_face_short": None,
"mediapipe_face_mesh": None,

View File

@@ -85,6 +85,14 @@ def on_generate_click(state: dict, *values: Any):
return state
def on_ad_model_update(model: str):
if "-world" in model:
return gr.update(
visible=True, placeholder="Comma separated class names to detect."
)
return gr.update(visible=False, placeholder="")
def on_cn_model_update(cn_model_name: str):
cn_model_name = cn_model_name.replace("inpaint_depth", "depth")
for t in cn_module_choices:
@@ -177,6 +185,20 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
elem_id=eid("ad_model"),
)
w.ad_model_classes = gr.Textbox(
label="ADetailer model classes" + suffix(n),
value="",
visible=False,
elem_id=eid("ad_classes"),
)
w.ad_model_classes.change(
on_ad_model_update,
inputs=w.ad_model,
outputs=w.ad_model_classes,
queue=False,
)
with gr.Group():
with gr.Row(elem_id=eid("ad_toprow_prompt")):
w.ad_prompt = gr.Textbox(

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import cv2
from PIL import Image
@@ -9,6 +10,10 @@ from torchvision.transforms.functional import to_pil_image
from adetailer import PredictOutput
from adetailer.common import create_mask_from_bbox
if TYPE_CHECKING:
import torch
from ultralytics import YOLO, YOLOWorld
def ultralytics_predict(
model_path: str | Path,
@@ -39,14 +44,14 @@ def ultralytics_predict(
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
def apply_classes(model, model_path: str | Path, classes: str):
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
if not classes or "-world" not in Path(model_path).stem:
return
parsed = [c.strip() for c in classes.split(",")]
model.set_classes(parsed)
def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
"""
Parameters
----------
@@ -54,7 +59,7 @@ def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
The device can be CUDA, but `to_pil_image` takes care of that.
shape: tuple[int, int]
(width, height) of the original image
(W, H) of the original image
"""
n = masks.shape[0]
return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]

View File

@@ -695,6 +695,7 @@ class AfterDetailerScript(scripts.Script):
predictor = ultralytics_predict
ad_model = self.get_ad_model(args.ad_model)
kwargs["device"] = self.ultralytics_device
kwargs["classes"] = args.ad_model_classes
with change_torch_load():
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)

View File

@@ -28,7 +28,7 @@ def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str):
def test_yolo_world_default(sample_image: Image.Image):
result = ultralytics_predict("yolov8l-world.pt", sample_image)
result = ultralytics_predict("yolov8x-world.pt", sample_image)
assert result.preview is not None
@@ -44,5 +44,5 @@ def test_yolo_world_default(sample_image: Image.Image):
],
)
def test_yolo_world(sample_image2: Image.Image, klass: str):
result = ultralytics_predict("yolov8l-world.pt", sample_image2, classes=klass)
result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass)
assert result.preview is not None