mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-30 19:21:33 +00:00
feat: add yolo world model
This commit is contained in:
@@ -1 +1 @@
|
|||||||
__version__ = "24.1.2"
|
__version__ = "24.3.0.dev0"
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class ArgsList(UserList):
|
|||||||
|
|
||||||
class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||||
ad_model: str = "None"
|
ad_model: str = "None"
|
||||||
|
ad_model_classes: str = ""
|
||||||
ad_prompt: str = ""
|
ad_prompt: str = ""
|
||||||
ad_negative_prompt: str = ""
|
ad_negative_prompt: str = ""
|
||||||
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
|
ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
|
||||||
@@ -111,6 +112,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
|||||||
p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
|
p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
|
||||||
ppop = partial(self.ppop, p)
|
ppop = partial(self.ppop, p)
|
||||||
|
|
||||||
|
ppop("ADetailer model classes")
|
||||||
ppop("ADetailer prompt")
|
ppop("ADetailer prompt")
|
||||||
ppop("ADetailer negative prompt")
|
ppop("ADetailer negative prompt")
|
||||||
ppop("ADetailer mask only top k largest", cond=0)
|
ppop("ADetailer mask only top k largest", cond=0)
|
||||||
@@ -183,6 +185,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
|||||||
|
|
||||||
_all_args = [
|
_all_args = [
|
||||||
("ad_model", "ADetailer model"),
|
("ad_model", "ADetailer model"),
|
||||||
|
("ad_model_classes", "ADetailer model classes"),
|
||||||
("ad_prompt", "ADetailer prompt"),
|
("ad_prompt", "ADetailer prompt"),
|
||||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||||
("ad_confidence", "ADetailer confidence"),
|
("ad_confidence", "ADetailer confidence"),
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ def get_models(
|
|||||||
)
|
)
|
||||||
models.update(
|
models.update(
|
||||||
{
|
{
|
||||||
|
"yolov8x-world.pt": "yolov8x-world.pt",
|
||||||
"mediapipe_face_full": None,
|
"mediapipe_face_full": None,
|
||||||
"mediapipe_face_short": None,
|
"mediapipe_face_short": None,
|
||||||
"mediapipe_face_mesh": None,
|
"mediapipe_face_mesh": None,
|
||||||
|
|||||||
@@ -85,6 +85,14 @@ def on_generate_click(state: dict, *values: Any):
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def on_ad_model_update(model: str):
|
||||||
|
if "-world" in model:
|
||||||
|
return gr.update(
|
||||||
|
visible=True, placeholder="Comma separated class names to detect."
|
||||||
|
)
|
||||||
|
return gr.update(visible=False, placeholder="")
|
||||||
|
|
||||||
|
|
||||||
def on_cn_model_update(cn_model_name: str):
|
def on_cn_model_update(cn_model_name: str):
|
||||||
cn_model_name = cn_model_name.replace("inpaint_depth", "depth")
|
cn_model_name = cn_model_name.replace("inpaint_depth", "depth")
|
||||||
for t in cn_module_choices:
|
for t in cn_module_choices:
|
||||||
@@ -177,6 +185,20 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
|
|||||||
elem_id=eid("ad_model"),
|
elem_id=eid("ad_model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
w.ad_model_classes = gr.Textbox(
|
||||||
|
label="ADetailer model classes" + suffix(n),
|
||||||
|
value="",
|
||||||
|
visible=False,
|
||||||
|
elem_id=eid("ad_classes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
w.ad_model_classes.change(
|
||||||
|
on_ad_model_update,
|
||||||
|
inputs=w.ad_model,
|
||||||
|
outputs=w.ad_model_classes,
|
||||||
|
queue=False,
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
with gr.Row(elem_id=eid("ad_toprow_prompt")):
|
with gr.Row(elem_id=eid("ad_toprow_prompt")):
|
||||||
w.ad_prompt = gr.Textbox(
|
w.ad_prompt = gr.Textbox(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -9,6 +10,10 @@ from torchvision.transforms.functional import to_pil_image
|
|||||||
from adetailer import PredictOutput
|
from adetailer import PredictOutput
|
||||||
from adetailer.common import create_mask_from_bbox
|
from adetailer.common import create_mask_from_bbox
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
from ultralytics import YOLO, YOLOWorld
|
||||||
|
|
||||||
|
|
||||||
def ultralytics_predict(
|
def ultralytics_predict(
|
||||||
model_path: str | Path,
|
model_path: str | Path,
|
||||||
@@ -39,14 +44,14 @@ def ultralytics_predict(
|
|||||||
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
||||||
|
|
||||||
|
|
||||||
def apply_classes(model, model_path: str | Path, classes: str):
|
def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
|
||||||
if not classes or "-world" not in Path(model_path).stem:
|
if not classes or "-world" not in Path(model_path).stem:
|
||||||
return
|
return
|
||||||
parsed = [c.strip() for c in classes.split(",")]
|
parsed = [c.strip() for c in classes.split(",")]
|
||||||
model.set_classes(parsed)
|
model.set_classes(parsed)
|
||||||
|
|
||||||
|
|
||||||
def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
|
def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -54,7 +59,7 @@ def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
|
|||||||
The device can be CUDA, but `to_pil_image` takes care of that.
|
The device can be CUDA, but `to_pil_image` takes care of that.
|
||||||
|
|
||||||
shape: tuple[int, int]
|
shape: tuple[int, int]
|
||||||
(width, height) of the original image
|
(W, H) of the original image
|
||||||
"""
|
"""
|
||||||
n = masks.shape[0]
|
n = masks.shape[0]
|
||||||
return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
|
return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
|
||||||
|
|||||||
@@ -695,6 +695,7 @@ class AfterDetailerScript(scripts.Script):
|
|||||||
predictor = ultralytics_predict
|
predictor = ultralytics_predict
|
||||||
ad_model = self.get_ad_model(args.ad_model)
|
ad_model = self.get_ad_model(args.ad_model)
|
||||||
kwargs["device"] = self.ultralytics_device
|
kwargs["device"] = self.ultralytics_device
|
||||||
|
kwargs["classes"] = args.ad_model_classes
|
||||||
|
|
||||||
with change_torch_load():
|
with change_torch_load():
|
||||||
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str):
|
|||||||
|
|
||||||
|
|
||||||
def test_yolo_world_default(sample_image: Image.Image):
|
def test_yolo_world_default(sample_image: Image.Image):
|
||||||
result = ultralytics_predict("yolov8l-world.pt", sample_image)
|
result = ultralytics_predict("yolov8x-world.pt", sample_image)
|
||||||
assert result.preview is not None
|
assert result.preview is not None
|
||||||
|
|
||||||
|
|
||||||
@@ -44,5 +44,5 @@ def test_yolo_world_default(sample_image: Image.Image):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_yolo_world(sample_image2: Image.Image, klass: str):
|
def test_yolo_world(sample_image2: Image.Image, klass: str):
|
||||||
result = ultralytics_predict("yolov8l-world.pt", sample_image2, classes=klass)
|
result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass)
|
||||||
assert result.preview is not None
|
assert result.preview is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user