diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 1d03a25..38dd1df 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -1,13 +1,16 @@ import gradio as gr import open_clip -import clip_interrogator +import os import torch +from PIL import Image + +import clip_interrogator from clip_interrogator import Config, Interrogator -from modules import devices, script_callbacks, shared, lowvram +from modules import devices, lowvram, script_callbacks, shared -__version__ = '0.0.5' +__version__ = '0.0.6' ci = None @@ -66,6 +69,21 @@ def image_analysis(image, clip_model_name): return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks +def interrogate(image, mode): + if mode == 'best': + prompt = ci.interrogate(image) + elif mode == 'caption': + prompt = ci.generate_caption(image) + elif mode == 'classic': + prompt = ci.interrogate_classic(image) + elif mode == 'fast': + prompt = ci.interrogate_fast(image) + elif mode == 'negative': + prompt = ci.interrogate_negative(image) + else: + raise Exception(f"Unknown mode {mode}") + return prompt + def image_to_prompt(image, mode, clip_model_name): shared.state.begin() shared.state.job = 'interrogate' @@ -76,17 +94,8 @@ def image_to_prompt(image, mode, clip_model_name): devices.torch_gc() load(clip_model_name) - image = image.convert('RGB') - - if mode == 'best': - prompt = ci.interrogate(image) - elif mode == 'classic': - prompt = ci.interrogate_classic(image) - elif mode == 'fast': - prompt = ci.interrogate_fast(image) - elif mode == 'negative': - prompt = ci.interrogate_negative(image) + prompt = interrogate(image, mode) except torch.cuda.OutOfMemoryError as e: prompt = "Ran out of VRAM" print(e) @@ -97,33 +106,6 @@ def image_to_prompt(image, mode, clip_model_name): shared.state.end() return prompt -def prompt_tab(): - with gr.Column(): - with gr.Row(): - image = gr.Image(type='pil', label="Image") - with gr.Column(): - mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best') - model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') - prompt = gr.Textbox(label="Prompt") - with gr.Row(): - button = gr.Button("Generate", variant='primary') - unload_button = gr.Button("Unload") - button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt) - unload_button.click(unload) - -def analyze_tab(): - with gr.Column(): - with gr.Row(): - image = gr.Image(type='pil', label="Image") - model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') - with gr.Row(): - medium = gr.Label(label="Medium", num_top_classes=5) - artist = gr.Label(label="Artist", num_top_classes=5) - movement = gr.Label(label="Movement", num_top_classes=5) - trending = gr.Label(label="Trending", num_top_classes=5) - flavor = gr.Label(label="Flavor", num_top_classes=5) - button = gr.Button("Analyze", variant='primary') - button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor]) def about_tab(): gr.Markdown("## 🕵️‍♂️ CLIP Interrogator 🕵️‍♂️") @@ -143,12 +125,101 @@ def about_tab(): vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2) gr.Markdown(f"GPU VRAM: {vram_total_mb:.2f}MB") +def analyze_tab(): + with gr.Column(): + with gr.Row(): + image = gr.Image(type='pil', label="Image") + model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') + with gr.Row(): + medium = gr.Label(label="Medium", num_top_classes=5) + artist = gr.Label(label="Artist", num_top_classes=5) + movement = gr.Label(label="Movement", num_top_classes=5) + trending = gr.Label(label="Trending", num_top_classes=5) + flavor = gr.Label(label="Flavor", num_top_classes=5) + button = gr.Button("Analyze", variant='primary') + button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor]) + +def batch_tab(): + def batch_process(folder, model, mode): + if not os.path.exists(folder): + return f"Folder {folder} does not exist" + if not os.path.isdir(folder): + return "{folder} is not a directory" + + files = [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + if not files: + return "Folder has no images" + + shared.state.begin() + shared.state.job = 'batch interrogate' + + try: + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + devices.torch_gc() + + load(model) + + shared.total_tqdm.updateTotal(len(files)) + ci.config.quiet = True + + for file in files: + image = Image.open(os.path.join(folder, file)) + prompt = interrogate(image, mode) + txt_file = os.path.splitext(file)[0] + ".txt" + with open(os.path.join(folder, txt_file), 'w', encoding='utf-8') as f: + f.write(prompt) + + shared.total_tqdm.update() + if shared.state.interrupted: + break + + ci.config.quiet = False + unload() + except torch.cuda.OutOfMemoryError as e: + print(e) + print("Ran out of VRAM!") + except RuntimeError as e: + print(e) + shared.state.end() + shared.total_tqdm.clear() + + with gr.Column(): + with gr.Row(): + folder = gr.Text(label="Images folder", value="", interactive=True) + with gr.Row(): + model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') + mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Mode', value='fast') + with gr.Row(): + button = gr.Button("Go!", variant='primary') + interrupt = gr.Button('Interrupt', visible=True) + interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[]) + + button.click(batch_process, inputs=[folder, model, mode], outputs=[]) + +def prompt_tab(): + with gr.Column(): + with gr.Row(): + image = gr.Image(type='pil', label="Image") + with gr.Column(): + mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best') + model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') + prompt = gr.Textbox(label="Prompt") + with gr.Row(): + button = gr.Button("Generate", variant='primary') + unload_button = gr.Button("Unload") + button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt) + unload_button.click(unload) + + def add_tab(): with gr.Blocks(analytics_enabled=False) as ui: with gr.Tab("Prompt"): prompt_tab() with gr.Tab("Analyze"): analyze_tab() + with gr.Tab("Batch"): + batch_tab() with gr.Tab("About"): about_tab()