mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 19:29:54 +00:00
feat: add yolo world model
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "24.1.2"
|
||||
__version__ = "24.3.0.dev0"
|
||||
|
||||
@@ -42,6 +42,7 @@ class ArgsList(UserList):
|
||||
|
||||
class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
ad_model: str = "None"
|
||||
ad_model_classes: str = ""
|
||||
ad_prompt: str = ""
|
||||
ad_negative_prompt: str = ""
|
||||
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
|
||||
@@ -111,6 +112,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
|
||||
ppop = partial(self.ppop, p)
|
||||
|
||||
ppop("ADetailer model classes")
|
||||
ppop("ADetailer prompt")
|
||||
ppop("ADetailer negative prompt")
|
||||
ppop("ADetailer mask only top k largest", cond=0)
|
||||
@@ -183,6 +185,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
|
||||
_all_args = [
|
||||
("ad_model", "ADetailer model"),
|
||||
("ad_model_classes", "ADetailer model classes"),
|
||||
("ad_prompt", "ADetailer prompt"),
|
||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||
("ad_confidence", "ADetailer confidence"),
|
||||
|
||||
@@ -60,6 +60,7 @@ def get_models(
|
||||
)
|
||||
models.update(
|
||||
{
|
||||
"yolov8x-world.pt": "yolov8x-world.pt",
|
||||
"mediapipe_face_full": None,
|
||||
"mediapipe_face_short": None,
|
||||
"mediapipe_face_mesh": None,
|
||||
|
||||
@@ -85,6 +85,14 @@ def on_generate_click(state: dict, *values: Any):
|
||||
return state
|
||||
|
||||
|
||||
def on_ad_model_update(model: str):
    """
    Build the gradio update that follows a change of the selected model.

    Parameters
    ----------
    model: str
        Name of the currently selected model.

    Returns
    -------
    The result of `gr.update(...)`: visible with a hint placeholder when
    the model name contains "-world", hidden with an empty placeholder
    otherwise.
    """
    is_world_model = "-world" in model
    if not is_world_model:
        return gr.update(visible=False, placeholder="")

    hint = "Comma separated class names to detect."
    return gr.update(visible=True, placeholder=hint)
|
||||
|
||||
|
||||
def on_cn_model_update(cn_model_name: str):
|
||||
cn_model_name = cn_model_name.replace("inpaint_depth", "depth")
|
||||
for t in cn_module_choices:
|
||||
@@ -177,6 +185,20 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
|
||||
elem_id=eid("ad_model"),
|
||||
)
|
||||
|
||||
w.ad_model_classes = gr.Textbox(
|
||||
label="ADetailer model classes" + suffix(n),
|
||||
value="",
|
||||
visible=False,
|
||||
elem_id=eid("ad_classes"),
|
||||
)
|
||||
|
||||
# Show or hide the class-name textbox whenever the selected model changes.
# The listener must be registered on the component whose value is the
# trigger (`w.ad_model`, which is also the handler's input). Registering it
# on `w.ad_model_classes` would only fire when the textbox itself is edited,
# and since that textbox starts hidden it could never become visible.
w.ad_model.change(
    on_ad_model_update,
    inputs=w.ad_model,
    outputs=w.ad_model_classes,
    queue=False,
)
|
||||
|
||||
with gr.Group():
|
||||
with gr.Row(elem_id=eid("ad_toprow_prompt")):
|
||||
w.ad_prompt = gr.Textbox(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
@@ -9,6 +10,10 @@ from torchvision.transforms.functional import to_pil_image
|
||||
from adetailer import PredictOutput
|
||||
from adetailer.common import create_mask_from_bbox
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from ultralytics import YOLO, YOLOWorld
|
||||
|
||||
|
||||
def ultralytics_predict(
|
||||
model_path: str | Path,
|
||||
@@ -39,14 +44,14 @@ def ultralytics_predict(
|
||||
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
||||
|
||||
|
||||
def apply_classes(model, model_path: str | Path, classes: str):
|
||||
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
|
||||
if not classes or "-world" not in Path(model_path).stem:
|
||||
return
|
||||
parsed = [c.strip() for c in classes.split(",")]
|
||||
model.set_classes(parsed)
|
||||
|
||||
|
||||
def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
|
||||
def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -54,7 +59,7 @@ def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
|
||||
The device can be CUDA, but `to_pil_image` takes care of that.
|
||||
|
||||
shape: tuple[int, int]
|
||||
(width, height) of the original image
|
||||
(W, H) of the original image
|
||||
"""
|
||||
n = masks.shape[0]
|
||||
return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
|
||||
|
||||
@@ -695,6 +695,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
predictor = ultralytics_predict
|
||||
ad_model = self.get_ad_model(args.ad_model)
|
||||
kwargs["device"] = self.ultralytics_device
|
||||
kwargs["classes"] = args.ad_model_classes
|
||||
|
||||
with change_torch_load():
|
||||
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str):
|
||||
|
||||
|
||||
def test_yolo_world_default(sample_image: Image.Image):
    """Smoke test: the YOLO-World model yields a preview without explicit classes."""
    prediction = ultralytics_predict("yolov8x-world.pt", sample_image)
    assert prediction.preview is not None
|
||||
|
||||
|
||||
@@ -44,5 +44,5 @@ def test_yolo_world_default(sample_image: Image.Image):
|
||||
],
|
||||
)
|
||||
def test_yolo_world(sample_image2: Image.Image, klass: str):
|
||||
result = ultralytics_predict("yolov8l-world.pt", sample_image2, classes=klass)
|
||||
result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass)
|
||||
assert result.preview is not None
|
||||
|
||||
Reference in New Issue
Block a user