From a690592efc9419492c51fefd996a33d08ae0b9d9 Mon Sep 17 00:00:00 2001 From: Christopher Pietsch Date: Tue, 27 Jun 2023 18:00:46 +0200 Subject: [PATCH 1/3] added API routes --- README.md | 12 +++++ scripts/clip_interrogator_ext.py | 77 +++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bda9ca3..625e2cb 100644 --- a/README.md +++ b/README.md @@ -14,3 +14,15 @@ This extension adds a tab for [CLIP Interrogator](https://github.com/pharmapsych * Paste `https://github.com/pharmapsychotic/clip-interrogator-ext` and click Install * Check in your terminal window if there are any errors (if so let me know!) * Restart the Web UI and you should see a new **Interrogator** tab + + +## API + +The CLIP Interrogator exposes a simple API to interact with the extension which is +documented on the /docs page under /interrogator/* (using --api flag when starting the Web UI) +* /interrogator/models + * lists all available models for interrogation +* /interrogator/prompt + * returns a prompt for the given image, model and mode +* /interrogator/analyse + * returns a list of words and their scores for the given image, model \ No newline at end of file diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 71245a3..6bd3533 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -3,6 +3,7 @@ import gradio as gr import open_clip import os import torch +import base64 from PIL import Image @@ -11,7 +12,12 @@ from clip_interrogator import Config, Interrogator from modules import devices, lowvram, script_callbacks, shared -__version__ = '0.1.5' +from pydantic import BaseModel, Field +from fastapi import FastAPI +from fastapi.exceptions import HTTPException +from io import BytesIO + +__version__ = "0.1.6" ci = None low_vram = False @@ -296,4 +302,73 @@ def add_tab(): return [(ui, "Interrogator", "interrogator")] +# decode_base64_to_image from modules/api/api.py, could be imported from there +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid encoded image") from e + +class InterrogatorAnalyseRequest(BaseModel): + image: str = Field( + default="", + title="Image", + description="Image to work on, must be a Base64 string containing the image's data.", + ) + clip_model_name: str = Field( + default="ViT-L-14/openai", + title="Model", + description="The interrogate model used. See the models endpoint for a list of available models.", + ) + +class InterrogatorPromptRequest(InterrogatorAnalyseRequest): + mode: str = Field( + default="fast", + title="Mode", + description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.", + ) + + +def mount_interrogator_api(_: gr.Blocks, app: FastAPI): + @app.get("/interrogator/models") + async def get_models(): + return ["/".join(x) for x in open_clip.list_pretrained()] + + @app.post("/interrogator/prompt") + async def get_prompt(analysereq: InterrogatorPromptRequest): + image_b64 = analysereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + img = decode_base64_to_image(image_b64) + prompt = image_to_prompt(img, analysereq.mode, analysereq.clip_model_name) + return {"prompt": prompt} + + @app.post("/interrogator/analyze") + async def analyze(analysereq: InterrogatorAnalyseRequest): + image_b64 = analysereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + img = decode_base64_to_image(image_b64) + ( + medium_ranks, + artist_ranks, + movement_ranks, + trending_ranks, + flavor_ranks, + ) = image_analysis(img, analysereq.clip_model_name) + return { + "medium": medium_ranks, + "artist": artist_ranks, + "movement": movement_ranks, + "trending": trending_ranks, + "flavor": flavor_ranks, + } + + +script_callbacks.on_app_started(mount_interrogator_api) script_callbacks.on_ui_tabs(add_tab) From f559d2d863a9a8b9cbd4223310db3f16b380b760 Mon Sep 17 00:00:00 2001 From: Christopher Pietsch Date: Tue, 27 Jun 2023 18:17:41 +0200 Subject: [PATCH 2/3] Typo: analyse instead of analyze --- scripts/clip_interrogator_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 6bd3533..01461ef 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -347,8 +347,8 @@ def mount_interrogator_api(_: gr.Blocks, app: FastAPI): prompt = image_to_prompt(img, analysereq.mode, analysereq.clip_model_name) return {"prompt": prompt} - @app.post("/interrogator/analyze") - async def analyze(analysereq: InterrogatorAnalyseRequest): + @app.post("/interrogator/analyse") + async def analyse(analysereq: InterrogatorAnalyseRequest): image_b64 = analysereq.image if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") From db4340b30f85ba0c74ad53bf564ec4dd785de250 Mon Sep 17 00:00:00 2001 From: Christopher Pietsch Date: Tue, 27 Jun 2023 19:25:35 +0200 Subject: [PATCH 3/3] renaming analyse to analyze --- README.md | 2 +- scripts/clip_interrogator_ext.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 625e2cb..883be13 100644 --- a/README.md +++ b/README.md @@ -24,5 +24,5 @@ documented on the /docs page under /interrogator/* (using --api flag when starti * lists all available models for interrogation * /interrogator/prompt * returns a prompt for the given image, model and mode -* /interrogator/analyse +* /interrogator/analyze * returns a list of words and their scores for the given image, model \ No newline at end of file diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 01461ef..c04c33e 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -312,7 +312,7 @@ def decode_base64_to_image(encoding): except Exception as e: raise HTTPException(status_code=500, detail="Invalid encoded image") from e -class InterrogatorAnalyseRequest(BaseModel): +class InterrogatorAnalyzeRequest(BaseModel): image: str = Field( default="", title="Image", @@ -324,32 +324,31 @@ class InterrogatorAnalyseRequest(BaseModel): description="The interrogate model used. See the models endpoint for a list of available models.", ) -class InterrogatorPromptRequest(InterrogatorAnalyseRequest): +class InterrogatorPromptRequest(InterrogatorAnalyzeRequest): mode: str = Field( default="fast", title="Mode", description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.", ) - def mount_interrogator_api(_: gr.Blocks, app: FastAPI): @app.get("/interrogator/models") async def get_models(): return ["/".join(x) for x in open_clip.list_pretrained()] @app.post("/interrogator/prompt") - async def get_prompt(analysereq: InterrogatorPromptRequest): - image_b64 = analysereq.image + async def get_prompt(analyzereq: InterrogatorPromptRequest): + image_b64 = analyzereq.image if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") img = decode_base64_to_image(image_b64) - prompt = image_to_prompt(img, analysereq.mode, analysereq.clip_model_name) + prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name) return {"prompt": prompt} - @app.post("/interrogator/analyse") - async def analyse(analysereq: InterrogatorAnalyseRequest): - image_b64 = analysereq.image + @app.post("/interrogator/analyze") + async def analyze(analyzereq: InterrogatorAnalyzeRequest): + image_b64 = analyzereq.image if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") @@ -360,7 +359,7 @@ def mount_interrogator_api(_: gr.Blocks, app: FastAPI): movement_ranks, trending_ranks, flavor_ranks, - ) = image_analysis(img, analysereq.clip_model_name) + ) = image_analysis(img, analyzereq.clip_model_name) return { "medium": medium_ranks, "artist": artist_ranks,