renaming analyse to analyze

This commit is contained in:
Christopher Pietsch
2023-06-27 19:25:35 +02:00
parent f559d2d863
commit db4340b30f
2 changed files with 10 additions and 11 deletions

View File

@@ -24,5 +24,5 @@ documented on the /docs page under /interrogator/* (using --api flag when starti
* lists all available models for interrogation * lists all available models for interrogation
* /interrogator/prompt * /interrogator/prompt
* returns a prompt for the given image, model and mode * returns a prompt for the given image, model and mode
* /interrogator/analyse * /interrogator/analyze
* returns a list of words and their scores for the given image, model * returns a list of words and their scores for the given image, model

View File

@@ -312,7 +312,7 @@ def decode_base64_to_image(encoding):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e raise HTTPException(status_code=500, detail="Invalid encoded image") from e
class InterrogatorAnalyseRequest(BaseModel): class InterrogatorAnalyzeRequest(BaseModel):
image: str = Field( image: str = Field(
default="", default="",
title="Image", title="Image",
@@ -324,32 +324,31 @@ class InterrogatorAnalyseRequest(BaseModel):
description="The interrogate model used. See the models endpoint for a list of available models.", description="The interrogate model used. See the models endpoint for a list of available models.",
) )
class InterrogatorPromptRequest(InterrogatorAnalyseRequest): class InterrogatorPromptRequest(InterrogatorAnalyzeRequest):
mode: str = Field( mode: str = Field(
default="fast", default="fast",
title="Mode", title="Mode",
description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.", description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.",
) )
def mount_interrogator_api(_: gr.Blocks, app: FastAPI): def mount_interrogator_api(_: gr.Blocks, app: FastAPI):
@app.get("/interrogator/models") @app.get("/interrogator/models")
async def get_models(): async def get_models():
return ["/".join(x) for x in open_clip.list_pretrained()] return ["/".join(x) for x in open_clip.list_pretrained()]
@app.post("/interrogator/prompt") @app.post("/interrogator/prompt")
async def get_prompt(analysereq: InterrogatorPromptRequest): async def get_prompt(analyzereq: InterrogatorPromptRequest):
image_b64 = analysereq.image image_b64 = analyzereq.image
if image_b64 is None: if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found") raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64) img = decode_base64_to_image(image_b64)
prompt = image_to_prompt(img, analysereq.mode, analysereq.clip_model_name) prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name)
return {"prompt": prompt} return {"prompt": prompt}
@app.post("/interrogator/analyse") @app.post("/interrogator/analyze")
async def analyse(analysereq: InterrogatorAnalyseRequest): async def analyze(analyzereq: InterrogatorAnalyzeRequest):
image_b64 = analysereq.image image_b64 = analyzereq.image
if image_b64 is None: if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found") raise HTTPException(status_code=404, detail="Image not found")
@@ -360,7 +359,7 @@ def mount_interrogator_api(_: gr.Blocks, app: FastAPI):
movement_ranks, movement_ranks,
trending_ranks, trending_ranks,
flavor_ranks, flavor_ranks,
) = image_analysis(img, analysereq.clip_model_name) ) = image_analysis(img, analyzereq.clip_model_name)
return { return {
"medium": medium_ranks, "medium": medium_ranks,
"artist": artist_ranks, "artist": artist_ranks,