diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index b44703e..087ff03 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -9,17 +9,25 @@ from adetailer import PredictOutput from adetailer.common import create_mask_from_bbox +def load_yolo(model_path: str | Path): + from ultralytics import YOLO + + try: + return YOLO(model_path) + except ModuleNotFoundError: + YOLO("yolov8n.pt") + return YOLO(model_path) + + def ultralytics_predict( model_path: str | Path, image: Image.Image, confidence: float = 0.3, device: str = "", ) -> PredictOutput: - from ultralytics import YOLO - model_path = str(model_path) - model = YOLO(model_path) + model = load_yolo(model_path) pred = model(image, conf=confidence, device=device) bboxes = pred[0].boxes.xyxy.cpu().numpy()