diff --git a/README.md b/README.md
index bda9ca3..883be13 100644
--- a/README.md
+++ b/README.md
@@ -14,3 +14,15 @@ This extension adds a tab for [CLIP Interrogator](https://github.com/pharmapsych
 * Paste `https://github.com/pharmapsychotic/clip-interrogator-ext` and click Install
 * Check in your terminal window if there are any errors (if so let me know!)
 * Restart the Web UI and you should see a new **Interrogator** tab
+
+
+## API
+
+The CLIP Interrogator exposes a simple API to interact with the extension which is
+documented on the /docs page under /interrogator/* (using --api flag when starting the Web UI)
+* /interrogator/models
+  * lists all available models for interrogation
+* /interrogator/prompt
+  * returns a prompt for the given image, model and mode
+* /interrogator/analyze
+  * returns a list of words and their scores for the given image, model
\ No newline at end of file
diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py
index 71245a3..c04c33e 100644
--- a/scripts/clip_interrogator_ext.py
+++ b/scripts/clip_interrogator_ext.py
@@ -3,6 +3,7 @@ import gradio as gr
 import open_clip
 import os
 import torch
+import base64
 
 from PIL import Image
 
@@ -11,7 +12,12 @@ from clip_interrogator import Config, Interrogator
 
 from modules import devices, lowvram, script_callbacks, shared
 
-__version__ = '0.1.5'
+from pydantic import BaseModel, Field
+from fastapi import FastAPI
+from fastapi.exceptions import HTTPException
+from io import BytesIO
+
+__version__ = "0.1.6"
 
 ci = None
 low_vram = False
@@ -296,4 +302,72 @@ def add_tab():
     return [(ui, "Interrogator", "interrogator")]
 
 
+# decode_base64_to_image from modules/api/api.py, could be imported from there
+def decode_base64_to_image(encoding):
+    if encoding.startswith("data:image/"):
+        encoding = encoding.split(";")[1].split(",")[1]
+    try:
+        image = Image.open(BytesIO(base64.b64decode(encoding)))
+        return image
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
+class InterrogatorAnalyzeRequest(BaseModel):
+    image: str = Field(
+        default="",
+        title="Image",
+        description="Image to work on, must be a Base64 string containing the image's data.",
+    )
+    clip_model_name: str = Field(
+        default="ViT-L-14/openai",
+        title="Model",
+        description="The interrogate model used. See the models endpoint for a list of available models.",
+    )
+
+class InterrogatorPromptRequest(InterrogatorAnalyzeRequest):
+    mode: str = Field(
+        default="fast",
+        title="Mode",
+        description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.",
+    )
+
+def mount_interrogator_api(_: gr.Blocks, app: FastAPI):
+    @app.get("/interrogator/models")
+    async def get_models():
+        return ["/".join(x) for x in open_clip.list_pretrained()]
+
+    @app.post("/interrogator/prompt")
+    async def get_prompt(analyzereq: InterrogatorPromptRequest):
+        image_b64 = analyzereq.image
+        if image_b64 is None:
+            raise HTTPException(status_code=404, detail="Image not found")
+
+        img = decode_base64_to_image(image_b64)
+        prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name)
+        return {"prompt": prompt}
+
+    @app.post("/interrogator/analyze")
+    async def analyze(analyzereq: InterrogatorAnalyzeRequest):
+        image_b64 = analyzereq.image
+        if image_b64 is None:
+            raise HTTPException(status_code=404, detail="Image not found")
+
+        img = decode_base64_to_image(image_b64)
+        (
+            medium_ranks,
+            artist_ranks,
+            movement_ranks,
+            trending_ranks,
+            flavor_ranks,
+        ) = image_analysis(img, analyzereq.clip_model_name)
+        return {
+            "medium": medium_ranks,
+            "artist": artist_ranks,
+            "movement": movement_ranks,
+            "trending": trending_ranks,
+            "flavor": flavor_ranks,
+        }
+
+
+script_callbacks.on_app_started(mount_interrogator_api)
 script_callbacks.on_ui_tabs(add_tab)