mirror of
https://github.com/pharmapsychotic/clip-interrogator-ext.git
synced 2026-01-26 19:29:53 +00:00
Add batch mode by popular request
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import gradio as gr
|
||||
import open_clip
|
||||
import clip_interrogator
|
||||
import os
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import clip_interrogator
|
||||
from clip_interrogator import Config, Interrogator
|
||||
|
||||
from modules import devices, script_callbacks, shared, lowvram
|
||||
from modules import devices, lowvram, script_callbacks, shared
|
||||
|
||||
__version__ = '0.0.5'
|
||||
__version__ = '0.0.6'
|
||||
|
||||
ci = None
|
||||
|
||||
@@ -66,6 +69,21 @@ def image_analysis(image, clip_model_name):
|
||||
|
||||
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
|
||||
|
||||
def interrogate(image, mode):
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image)
|
||||
elif mode == 'caption':
|
||||
prompt = ci.generate_caption(image)
|
||||
elif mode == 'classic':
|
||||
prompt = ci.interrogate_classic(image)
|
||||
elif mode == 'fast':
|
||||
prompt = ci.interrogate_fast(image)
|
||||
elif mode == 'negative':
|
||||
prompt = ci.interrogate_negative(image)
|
||||
else:
|
||||
raise Exception(f"Unknown mode {mode}")
|
||||
return prompt
|
||||
|
||||
def image_to_prompt(image, mode, clip_model_name):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
@@ -76,17 +94,8 @@ def image_to_prompt(image, mode, clip_model_name):
|
||||
devices.torch_gc()
|
||||
|
||||
load(clip_model_name)
|
||||
|
||||
image = image.convert('RGB')
|
||||
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image)
|
||||
elif mode == 'classic':
|
||||
prompt = ci.interrogate_classic(image)
|
||||
elif mode == 'fast':
|
||||
prompt = ci.interrogate_fast(image)
|
||||
elif mode == 'negative':
|
||||
prompt = ci.interrogate_negative(image)
|
||||
prompt = interrogate(image, mode)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
prompt = "Ran out of VRAM"
|
||||
print(e)
|
||||
@@ -97,33 +106,6 @@ def image_to_prompt(image, mode, clip_model_name):
|
||||
shared.state.end()
|
||||
return prompt
|
||||
|
||||
def prompt_tab():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
image = gr.Image(type='pil', label="Image")
|
||||
with gr.Column():
|
||||
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
prompt = gr.Textbox(label="Prompt")
|
||||
with gr.Row():
|
||||
button = gr.Button("Generate", variant='primary')
|
||||
unload_button = gr.Button("Unload")
|
||||
button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
|
||||
unload_button.click(unload)
|
||||
|
||||
def analyze_tab():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
image = gr.Image(type='pil', label="Image")
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
with gr.Row():
|
||||
medium = gr.Label(label="Medium", num_top_classes=5)
|
||||
artist = gr.Label(label="Artist", num_top_classes=5)
|
||||
movement = gr.Label(label="Movement", num_top_classes=5)
|
||||
trending = gr.Label(label="Trending", num_top_classes=5)
|
||||
flavor = gr.Label(label="Flavor", num_top_classes=5)
|
||||
button = gr.Button("Analyze", variant='primary')
|
||||
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
|
||||
|
||||
def about_tab():
|
||||
gr.Markdown("## 🕵️♂️ CLIP Interrogator 🕵️♂️")
|
||||
@@ -143,12 +125,101 @@ def about_tab():
|
||||
vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2)
|
||||
gr.Markdown(f"GPU VRAM: {vram_total_mb:.2f}MB")
|
||||
|
||||
def analyze_tab():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
image = gr.Image(type='pil', label="Image")
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
with gr.Row():
|
||||
medium = gr.Label(label="Medium", num_top_classes=5)
|
||||
artist = gr.Label(label="Artist", num_top_classes=5)
|
||||
movement = gr.Label(label="Movement", num_top_classes=5)
|
||||
trending = gr.Label(label="Trending", num_top_classes=5)
|
||||
flavor = gr.Label(label="Flavor", num_top_classes=5)
|
||||
button = gr.Button("Analyze", variant='primary')
|
||||
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
|
||||
|
||||
def batch_tab():
|
||||
def batch_process(folder, model, mode):
|
||||
if not os.path.exists(folder):
|
||||
return f"Folder {folder} does not exist"
|
||||
if not os.path.isdir(folder):
|
||||
return "{folder} is not a directory"
|
||||
|
||||
files = [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
||||
if not files:
|
||||
return "Folder has no images"
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'batch interrogate'
|
||||
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
||||
load(model)
|
||||
|
||||
shared.total_tqdm.updateTotal(len(files))
|
||||
ci.config.quiet = True
|
||||
|
||||
for file in files:
|
||||
image = Image.open(os.path.join(folder, file))
|
||||
prompt = interrogate(image, mode)
|
||||
txt_file = os.path.splitext(file)[0] + ".txt"
|
||||
with open(os.path.join(folder, txt_file), 'w', encoding='utf-8') as f:
|
||||
f.write(prompt)
|
||||
|
||||
shared.total_tqdm.update()
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
ci.config.quiet = False
|
||||
unload()
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
print(e)
|
||||
print("Ran out of VRAM!")
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
folder = gr.Text(label="Images folder", value="", interactive=True)
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Mode', value='fast')
|
||||
with gr.Row():
|
||||
button = gr.Button("Go!", variant='primary')
|
||||
interrupt = gr.Button('Interrupt', visible=True)
|
||||
interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
|
||||
|
||||
button.click(batch_process, inputs=[folder, model, mode], outputs=[])
|
||||
|
||||
def prompt_tab():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
image = gr.Image(type='pil', label="Image")
|
||||
with gr.Column():
|
||||
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
prompt = gr.Textbox(label="Prompt")
|
||||
with gr.Row():
|
||||
button = gr.Button("Generate", variant='primary')
|
||||
unload_button = gr.Button("Unload")
|
||||
button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
|
||||
unload_button.click(unload)
|
||||
|
||||
|
||||
def add_tab():
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tab("Prompt"):
|
||||
prompt_tab()
|
||||
with gr.Tab("Analyze"):
|
||||
analyze_tab()
|
||||
with gr.Tab("Batch"):
|
||||
batch_tab()
|
||||
with gr.Tab("About"):
|
||||
about_tab()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user