diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index fc68a61..074b9bc 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -4,7 +4,7 @@ import open_clip import os import torch -from PIL import Image +from PIL import Image, UnidentifiedImageError import clip_interrogator from clip_interrogator import Config, Interrogator @@ -55,7 +55,7 @@ def load(clip_model_name): print(f"Loading CLIP Interrogator {clip_interrogator.__version__}...") config = Config( - device=devices.get_optimal_device(), + device=devices.get_optimal_device(), cache_path = 'models/clip-interrogator', clip_model_name=clip_model_name, blip_model=shared.interrogator.load_blip_model().float() @@ -95,7 +95,7 @@ def image_analysis(image, clip_model_name): movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))} trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))} flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))} - + return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks def interrogate(image, mode, caption=None): @@ -117,7 +117,7 @@ def image_to_prompt(image, mode, clip_model_name): shared.state.begin() shared.state.job = 'interrogate' - try: + try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() devices.torch_gc() @@ -169,13 +169,14 @@ def analyze_tab(): model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') with gr.Row(): medium = gr.Label(label="Medium", num_top_classes=5) - artist = gr.Label(label="Artist", num_top_classes=5) + artist = gr.Label(label="Artist", num_top_classes=5) movement = gr.Label(label="Movement", num_top_classes=5) trending = gr.Label(label="Trending", num_top_classes=5) flavor = gr.Label(label="Flavor", num_top_classes=5) button = gr.Button("Analyze", variant='primary') button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor]) + def batch_tab(): def batch_process(folder, clip_model, mode, output_mode): if not os.path.exists(folder): @@ -190,7 +191,7 @@ def batch_tab(): shared.state.begin() shared.state.job = 'batch interrogate' - try: + try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() devices.torch_gc() @@ -202,25 +203,40 @@ def batch_tab(): # generate captions in first pass captions = [] + for file in files: - if shared.state.interrupted: - break - image = Image.open(os.path.join(folder, file)).convert('RGB') - captions.append(ci.generate_caption(image)) - shared.total_tqdm.update() + try: + if shared.state.interrupted: + break + image = Image.open(os.path.join(folder, file)).convert('RGB') + captions.append(ci.generate_caption(image)) + shared.total_tqdm.update() + except OSError: + print(f" Could not read {file}; continuing") + continue + finally: + shared.total_tqdm.update() # interrogate in second pass writer = BatchWriter(folder, output_mode) shared.total_tqdm.clear() shared.total_tqdm.updateTotal(len(files)) for idx, file in enumerate(files): - if shared.state.interrupted: - break - image = Image.open(os.path.join(folder, file)).convert('RGB') - prompt = interrogate(image, mode, caption=captions[idx]) - writer.add(file, prompt) - shared.total_tqdm.update() - + try: + if shared.state.interrupted: + break + image = Image.open(os.path.join(folder, file)).convert('RGB') + try: + prompt = interrogate(image, mode, caption=captions[idx]) + writer.add(file, prompt) + except IndexError as e: + print(f" {e}, continuing") + continue + except OSError as e: + print(f" {e}, continuing") + continue + finally: + shared.total_tqdm.update() writer.close() ci.config.quiet = False unload() @@ -239,7 +255,7 @@ def batch_tab(): clip_model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Prompt Mode', value='fast') output_mode = gr.Dropdown(BATCH_OUTPUT_MODES, value=BATCH_OUTPUT_MODES[0], label='Output Mode') - with gr.Row(): + with gr.Row(): button = gr.Button("Go!", variant='primary') interrupt = gr.Button('Interrupt', visible=True) interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])