diff --git a/adetailer/__version__.py b/adetailer/__version__.py index 869aead..9d86149 100644 --- a/adetailer/__version__.py +++ b/adetailer/__version__.py @@ -1 +1 @@ -__version__ = "24.1.2" +__version__ = "24.3.0.dev0" diff --git a/adetailer/args.py b/adetailer/args.py index b0f311e..bda899d 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -42,6 +42,7 @@ class ArgsList(UserList): class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_model: str = "None" + ad_model_classes: str = "" ad_prompt: str = "" ad_negative_prompt: str = "" ad_confidence: confloat(ge=0.0, le=1.0) = 0.3 @@ -111,6 +112,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): p = {name: getattr(self, attr) for attr, name in ALL_ARGS} ppop = partial(self.ppop, p) + ppop("ADetailer model classes") ppop("ADetailer prompt") ppop("ADetailer negative prompt") ppop("ADetailer mask only top k largest", cond=0) @@ -183,6 +185,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): _all_args = [ ("ad_model", "ADetailer model"), + ("ad_model_classes", "ADetailer model classes"), ("ad_prompt", "ADetailer prompt"), ("ad_negative_prompt", "ADetailer negative prompt"), ("ad_confidence", "ADetailer confidence"), diff --git a/adetailer/common.py b/adetailer/common.py index 7d0460d..5a8b1ab 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -60,6 +60,7 @@ def get_models( ) models.update( { + "yolov8x-world.pt": "yolov8x-world.pt", "mediapipe_face_full": None, "mediapipe_face_short": None, "mediapipe_face_mesh": None, diff --git a/adetailer/ui.py b/adetailer/ui.py index a044a53..cfd665d 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -85,6 +85,14 @@ def on_generate_click(state: dict, *values: Any): return state +def on_ad_model_update(model: str): + if "-world" in model: + return gr.update( + visible=True, placeholder="Comma separated class names to detect." + ) + return gr.update(visible=False, placeholder="") + + def on_cn_model_update(cn_model_name: str): cn_model_name = cn_model_name.replace("inpaint_depth", "depth") for t in cn_module_choices: @@ -177,6 +185,20 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo): elem_id=eid("ad_model"), ) + w.ad_model_classes = gr.Textbox( + label="ADetailer model classes" + suffix(n), + value="", + visible=False, + elem_id=eid("ad_classes"), + ) + + w.ad_model_classes.change( + on_ad_model_update, + inputs=w.ad_model, + outputs=w.ad_model_classes, + queue=False, + ) + with gr.Group(): with gr.Row(elem_id=eid("ad_toprow_prompt")): w.ad_prompt = gr.Textbox( diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 8483b12..de19ff7 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING import cv2 from PIL import Image @@ -9,6 +10,10 @@ from torchvision.transforms.functional import to_pil_image from adetailer import PredictOutput from adetailer.common import create_mask_from_bbox +if TYPE_CHECKING: + import torch + from ultralytics import YOLO, YOLOWorld + def ultralytics_predict( model_path: str | Path, @@ -39,14 +44,14 @@ def ultralytics_predict( return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) -def apply_classes(model, model_path: str | Path, classes: str): +def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str): if not classes or "-world" not in Path(model_path).stem: return parsed = [c.strip() for c in classes.split(",")] model.set_classes(parsed) -def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]: +def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]: """ Parameters ---------- @@ -54,7 +59,7 @@ def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]: The device can be CUDA, but `to_pil_image` takes care of that. shape: tuple[int, int] - (width, height) of the original image + (W, H) of the original image """ n = masks.shape[0] return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)] diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 747798c..8fa8352 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -695,6 +695,7 @@ class AfterDetailerScript(scripts.Script): predictor = ultralytics_predict ad_model = self.get_ad_model(args.ad_model) kwargs["device"] = self.ultralytics_device + kwargs["classes"] = args.ad_model_classes with change_torch_load(): pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs) diff --git a/tests/test_ultralytics.py b/tests/test_ultralytics.py index ad855ad..f3885de 100644 --- a/tests/test_ultralytics.py +++ b/tests/test_ultralytics.py @@ -28,7 +28,7 @@ def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str): def test_yolo_world_default(sample_image: Image.Image): - result = ultralytics_predict("yolov8l-world.pt", sample_image) + result = ultralytics_predict("yolov8x-world.pt", sample_image) assert result.preview is not None @@ -44,5 +44,5 @@ def test_yolo_world_default(sample_image: Image.Image): ], ) def test_yolo_world(sample_image2: Image.Image, klass: str): - result = ultralytics_predict("yolov8l-world.pt", sample_image2, classes=klass) + result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass) assert result.preview is not None