diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 38dd1df..b9edf15 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -1,3 +1,4 @@ +import csv import gradio as gr import open_clip import os @@ -10,10 +11,44 @@ from clip_interrogator import Config, Interrogator from modules import devices, lowvram, script_callbacks, shared -__version__ = '0.0.6' +__version__ = '0.0.7' ci = None +BATCH_OUTPUT_MODES = [ + 'Text file for each image', + 'Single text file with all prompts', + 'csv file with columns for filenames and prompts', +] + +class BatchWriter: + def __init__(self, folder, mode): + self.folder = folder + self.mode = mode + self.csv, self.file = None, None + if mode == BATCH_OUTPUT_MODES[1]: + self.file = open(os.path.join(folder, 'batch.txt'), 'w', encoding='utf-8') + elif mode == BATCH_OUTPUT_MODES[2]: + self.file = open(os.path.join(folder, 'batch.csv'), 'w', encoding='utf-8', newline='') + self.csv = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL) + self.csv.writerow(['filename', 'prompt']) + + def add(self, file, prompt): + if self.mode == BATCH_OUTPUT_MODES[0]: + txt_file = os.path.splitext(file)[0] + ".txt" + with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f: + f.write(prompt) + elif self.mode == BATCH_OUTPUT_MODES[1]: + self.file.write(f"{prompt}\n") + elif self.mode == BATCH_OUTPUT_MODES[2]: + self.file.write(f"{file},{prompt}\n") + self.csv.writerow([file, prompt]) + + def close(self): + if self.file is not None: + self.file.close() + + def load(clip_model_name): global ci if ci is None: @@ -69,15 +104,15 @@ def image_analysis(image, clip_model_name): return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks -def interrogate(image, mode): +def interrogate(image, mode, caption=None): if mode == 'best': - prompt = ci.interrogate(image) + prompt = ci.interrogate(image, caption=caption) elif mode == 'caption': - prompt = ci.generate_caption(image) + prompt = ci.generate_caption(image) if caption is None else caption elif mode == 'classic': - prompt = ci.interrogate_classic(image) + prompt = ci.interrogate_classic(image, caption=caption) elif mode == 'fast': - prompt = ci.interrogate_fast(image) + prompt = ci.interrogate_fast(image, caption=caption) elif mode == 'negative': prompt = ci.interrogate_negative(image) else: @@ -140,7 +175,7 @@ def analyze_tab(): button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor]) def batch_tab(): - def batch_process(folder, model, mode): + def batch_process(folder, model, mode, output_mode): if not os.path.exists(folder): return f"Folder {folder} does not exist" if not os.path.isdir(folder): @@ -163,17 +198,28 @@ def batch_tab(): shared.total_tqdm.updateTotal(len(files)) ci.config.quiet = True + # generate captions in first pass + captions = [] for file in files: - image = Image.open(os.path.join(folder, file)) - prompt = interrogate(image, mode) - txt_file = os.path.splitext(file)[0] + ".txt" - with open(os.path.join(folder, txt_file), 'w', encoding='utf-8') as f: - f.write(prompt) - - shared.total_tqdm.update() if shared.state.interrupted: break + image = Image.open(os.path.join(folder, file)) + captions.append(ci.generate_caption(image)) + shared.total_tqdm.update() + # interrogate in second pass + writer = BatchWriter(folder, output_mode) + shared.total_tqdm.clear() + shared.total_tqdm.updateTotal(len(files)) + for idx, file in enumerate(files): + if shared.state.interrupted: + break + image = Image.open(os.path.join(folder, file)) + prompt = interrogate(image, mode, caption=captions[idx]) + writer.add(file, prompt) + shared.total_tqdm.update() + + writer.close() ci.config.quiet = False unload() except torch.cuda.OutOfMemoryError as e: @@ -189,13 +235,14 @@ def batch_tab(): folder = gr.Text(label="Images folder", value="", interactive=True) with gr.Row(): model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') - mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Mode', value='fast') + mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Prompt Mode', value='fast') + output_mode = gr.Dropdown(BATCH_OUTPUT_MODES, value=BATCH_OUTPUT_MODES[0], label='Output Mode') with gr.Row(): button = gr.Button("Go!", variant='primary') interrupt = gr.Button('Interrupt', visible=True) interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[]) - button.click(batch_process, inputs=[folder, model, mode], outputs=[]) + button.click(batch_process, inputs=[folder, model, mode, output_mode], outputs=[]) def prompt_tab(): with gr.Column():