Merge pull request #57 from hfg-gmuend/main

added API routes
pharmapsychotic
2023-06-27 14:06:31 -05:00
committed by GitHub
2 changed files with 87 additions and 1 deletion

README.md

@@ -14,3 +14,15 @@ This extension adds a tab for [CLIP Interrogator](https://github.com/pharmapsych
* Paste `https://github.com/pharmapsychotic/clip-interrogator-ext` and click Install
* Check in your terminal window if there are any errors (if so let me know!)
* Restart the Web UI and you should see a new **Interrogator** tab
## API
The CLIP Interrogator extension exposes a simple API, documented on the `/docs` page
under `/interrogator/*` (start the Web UI with the `--api` flag to enable it):
* `/interrogator/models`
  * lists all available models for interrogation
* `/interrogator/prompt`
  * returns a prompt for the given image, model, and mode
* `/interrogator/analyze`
  * returns a list of words and their scores for the given image and model
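
A minimal Python client for the `/interrogator/prompt` endpoint might look like the sketch below (the base URL assumes the Web UI's default local address, and `photo.jpg` is a placeholder for your own image; the field names come from the request models in this PR):

```python
import base64
import requests

# Read an image file and encode it as Base64, as the `image` field expects
with open("photo.jpg", "rb") as f:  # placeholder path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://127.0.0.1:7860/interrogator/prompt",  # default local Web UI address (adjust as needed)
    json={
        "image": image_b64,
        "clip_model_name": "ViT-L-14/openai",
        "mode": "fast",  # one of: best, fast, classic, negative
    },
)
response.raise_for_status()
print(response.json()["prompt"])
```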

scripts/clip_interrogator_ext.py

@@ -3,6 +3,7 @@ import gradio as gr
import open_clip
import os
import torch
import base64
from PIL import Image
@@ -11,7 +12,12 @@ from clip_interrogator import Config, Interrogator
from modules import devices, lowvram, script_callbacks, shared
__version__ = '0.1.5'
from pydantic import BaseModel, Field
from fastapi import FastAPI
from fastapi.exceptions import HTTPException
from io import BytesIO
__version__ = "0.1.6"
ci = None
low_vram = False
@@ -296,4 +302,72 @@ def add_tab():
    return [(ui, "Interrogator", "interrogator")]
# decode_base64_to_image is copied from modules/api/api.py; it could be imported from there instead
def decode_base64_to_image(encoding):
    # Strip a data-URI prefix like "data:image/png;base64," if present
    if encoding.startswith("data:image/"):
        encoding = encoding.split(";")[1].split(",")[1]
    try:
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e

class InterrogatorAnalyzeRequest(BaseModel):
    image: str = Field(
        default="",
        title="Image",
        description="Image to work on, must be a Base64 string containing the image's data.",
    )
    clip_model_name: str = Field(
        default="ViT-L-14/openai",
        title="Model",
        description="The interrogate model used. See the models endpoint for a list of available models.",
    )

class InterrogatorPromptRequest(InterrogatorAnalyzeRequest):
    mode: str = Field(
        default="fast",
        title="Mode",
        description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.",
    )

def mount_interrogator_api(_: gr.Blocks, app: FastAPI):
    @app.get("/interrogator/models")
    async def get_models():
        return ["/".join(x) for x in open_clip.list_pretrained()]

    @app.post("/interrogator/prompt")
    async def get_prompt(analyzereq: InterrogatorPromptRequest):
        image_b64 = analyzereq.image
        if not image_b64:  # the field defaults to "", so treat an empty string as missing too
            raise HTTPException(status_code=404, detail="Image not found")
        img = decode_base64_to_image(image_b64)
        prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name)
        return {"prompt": prompt}
@app.post("/interrogator/analyze")
async def analyze(analyzereq: InterrogatorAnalyzeRequest):
image_b64 = analyzereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64)
(
medium_ranks,
artist_ranks,
movement_ranks,
trending_ranks,
flavor_ranks,
) = image_analysis(img, analyzereq.clip_model_name)
return {
"medium": medium_ranks,
"artist": artist_ranks,
"movement": movement_ranks,
"trending": trending_ranks,
"flavor": flavor_ranks,
}
script_callbacks.on_app_started(mount_interrogator_api)
script_callbacks.on_ui_tabs(add_tab)
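
As a quick end-to-end check of the new `/interrogator/analyze` route, a request like the following should work against a locally running Web UI (same assumptions as in the README example above: default base URL and a placeholder image path; the field names come from `InterrogatorAnalyzeRequest`):

```python
import base64
import requests

with open("photo.jpg", "rb") as f:  # placeholder path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:7860/interrogator/analyze",  # default local Web UI address (assumption)
    json={"image": image_b64, "clip_model_name": "ViT-L-14/openai"},
)
resp.raise_for_status()
# The response holds one ranked table per category: medium, artist, movement, trending, flavor
for category, ranks in resp.json().items():
    print(category, ranks)
```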