mirror of
https://github.com/pharmapsychotic/clip-interrogator-ext.git
synced 2026-04-30 03:01:42 +00:00
Different outputs modes for Batch tab
Save to individual text files, all in single text file, or csv file.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import csv
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import open_clip
|
import open_clip
|
||||||
import os
|
import os
|
||||||
@@ -10,10 +11,44 @@ from clip_interrogator import Config, Interrogator
|
|||||||
|
|
||||||
from modules import devices, lowvram, script_callbacks, shared
|
from modules import devices, lowvram, script_callbacks, shared
|
||||||
|
|
||||||
__version__ = '0.0.6'
|
__version__ = '0.0.7'
|
||||||
|
|
||||||
ci = None
|
ci = None
|
||||||
|
|
||||||
|
BATCH_OUTPUT_MODES = [
|
||||||
|
'Text file for each image',
|
||||||
|
'Single text file with all prompts',
|
||||||
|
'csv file with columns for filenames and prompts',
|
||||||
|
]
|
||||||
|
|
||||||
|
class BatchWriter:
|
||||||
|
def __init__(self, folder, mode):
|
||||||
|
self.folder = folder
|
||||||
|
self.mode = mode
|
||||||
|
self.csv, self.file = None, None
|
||||||
|
if mode == BATCH_OUTPUT_MODES[1]:
|
||||||
|
self.file = open(os.path.join(folder, 'batch.txt'), 'w', encoding='utf-8')
|
||||||
|
elif mode == BATCH_OUTPUT_MODES[2]:
|
||||||
|
self.file = open(os.path.join(folder, 'batch.csv'), 'w', encoding='utf-8', newline='')
|
||||||
|
self.csv = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL)
|
||||||
|
self.csv.writerow(['filename', 'prompt'])
|
||||||
|
|
||||||
|
def add(self, file, prompt):
|
||||||
|
if self.mode == BATCH_OUTPUT_MODES[0]:
|
||||||
|
txt_file = os.path.splitext(file)[0] + ".txt"
|
||||||
|
with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f:
|
||||||
|
f.write(prompt)
|
||||||
|
elif self.mode == BATCH_OUTPUT_MODES[1]:
|
||||||
|
self.file.write(f"{prompt}\n")
|
||||||
|
elif self.mode == BATCH_OUTPUT_MODES[2]:
|
||||||
|
self.file.write(f"{file},{prompt}\n")
|
||||||
|
self.csv.writerow([file, prompt])
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.file is not None:
|
||||||
|
self.file.close()
|
||||||
|
|
||||||
|
|
||||||
def load(clip_model_name):
|
def load(clip_model_name):
|
||||||
global ci
|
global ci
|
||||||
if ci is None:
|
if ci is None:
|
||||||
@@ -69,15 +104,15 @@ def image_analysis(image, clip_model_name):
|
|||||||
|
|
||||||
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
|
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
|
||||||
|
|
||||||
def interrogate(image, mode):
|
def interrogate(image, mode, caption=None):
|
||||||
if mode == 'best':
|
if mode == 'best':
|
||||||
prompt = ci.interrogate(image)
|
prompt = ci.interrogate(image, caption=caption)
|
||||||
elif mode == 'caption':
|
elif mode == 'caption':
|
||||||
prompt = ci.generate_caption(image)
|
prompt = ci.generate_caption(image) if caption is None else caption
|
||||||
elif mode == 'classic':
|
elif mode == 'classic':
|
||||||
prompt = ci.interrogate_classic(image)
|
prompt = ci.interrogate_classic(image, caption=caption)
|
||||||
elif mode == 'fast':
|
elif mode == 'fast':
|
||||||
prompt = ci.interrogate_fast(image)
|
prompt = ci.interrogate_fast(image, caption=caption)
|
||||||
elif mode == 'negative':
|
elif mode == 'negative':
|
||||||
prompt = ci.interrogate_negative(image)
|
prompt = ci.interrogate_negative(image)
|
||||||
else:
|
else:
|
||||||
@@ -140,7 +175,7 @@ def analyze_tab():
|
|||||||
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
|
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
|
||||||
|
|
||||||
def batch_tab():
|
def batch_tab():
|
||||||
def batch_process(folder, model, mode):
|
def batch_process(folder, model, mode, output_mode):
|
||||||
if not os.path.exists(folder):
|
if not os.path.exists(folder):
|
||||||
return f"Folder {folder} does not exist"
|
return f"Folder {folder} does not exist"
|
||||||
if not os.path.isdir(folder):
|
if not os.path.isdir(folder):
|
||||||
@@ -163,17 +198,28 @@ def batch_tab():
|
|||||||
shared.total_tqdm.updateTotal(len(files))
|
shared.total_tqdm.updateTotal(len(files))
|
||||||
ci.config.quiet = True
|
ci.config.quiet = True
|
||||||
|
|
||||||
|
# generate captions in first pass
|
||||||
|
captions = []
|
||||||
for file in files:
|
for file in files:
|
||||||
image = Image.open(os.path.join(folder, file))
|
|
||||||
prompt = interrogate(image, mode)
|
|
||||||
txt_file = os.path.splitext(file)[0] + ".txt"
|
|
||||||
with open(os.path.join(folder, txt_file), 'w', encoding='utf-8') as f:
|
|
||||||
f.write(prompt)
|
|
||||||
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
image = Image.open(os.path.join(folder, file))
|
||||||
|
captions.append(ci.generate_caption(image))
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
# interrogate in second pass
|
||||||
|
writer = BatchWriter(folder, output_mode)
|
||||||
|
shared.total_tqdm.clear()
|
||||||
|
shared.total_tqdm.updateTotal(len(files))
|
||||||
|
for idx, file in enumerate(files):
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
image = Image.open(os.path.join(folder, file))
|
||||||
|
prompt = interrogate(image, mode, caption=captions[idx])
|
||||||
|
writer.add(file, prompt)
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
writer.close()
|
||||||
ci.config.quiet = False
|
ci.config.quiet = False
|
||||||
unload()
|
unload()
|
||||||
except torch.cuda.OutOfMemoryError as e:
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
@@ -189,13 +235,14 @@ def batch_tab():
|
|||||||
folder = gr.Text(label="Images folder", value="", interactive=True)
|
folder = gr.Text(label="Images folder", value="", interactive=True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||||
mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Mode', value='fast')
|
mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Prompt Mode', value='fast')
|
||||||
|
output_mode = gr.Dropdown(BATCH_OUTPUT_MODES, value=BATCH_OUTPUT_MODES[0], label='Output Mode')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
button = gr.Button("Go!", variant='primary')
|
button = gr.Button("Go!", variant='primary')
|
||||||
interrupt = gr.Button('Interrupt', visible=True)
|
interrupt = gr.Button('Interrupt', visible=True)
|
||||||
interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
|
interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
|
||||||
|
|
||||||
button.click(batch_process, inputs=[folder, model, mode], outputs=[])
|
button.click(batch_process, inputs=[folder, model, mode, output_mode], outputs=[])
|
||||||
|
|
||||||
def prompt_tab():
|
def prompt_tab():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|||||||
Reference in New Issue
Block a user