adds try/catch/finally blocks to handle bad images

This commit is contained in:
genevera
2023-05-31 16:27:50 -04:00
parent 9e6bbd9b89
commit 39904fe789

View File

@@ -4,7 +4,7 @@ import open_clip
import os
import torch
from PIL import Image
from PIL import Image, UnidentifiedImageError
import clip_interrogator
from clip_interrogator import Config, Interrogator
@@ -55,7 +55,7 @@ def load(clip_model_name):
print(f"Loading CLIP Interrogator {clip_interrogator.__version__}...")
config = Config(
device=devices.get_optimal_device(),
device=devices.get_optimal_device(),
cache_path = 'models/clip-interrogator',
clip_model_name=clip_model_name,
blip_model=shared.interrogator.load_blip_model().float()
@@ -95,7 +95,7 @@ def image_analysis(image, clip_model_name):
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def interrogate(image, mode, caption=None):
@@ -117,7 +117,7 @@ def image_to_prompt(image, mode, clip_model_name):
shared.state.begin()
shared.state.job = 'interrogate'
try:
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -169,13 +169,14 @@ def analyze_tab():
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
button = gr.Button("Analyze", variant='primary')
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
def batch_tab():
def batch_process(folder, clip_model, mode, output_mode):
if not os.path.exists(folder):
@@ -190,7 +191,7 @@ def batch_tab():
shared.state.begin()
shared.state.job = 'batch interrogate'
try:
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -202,25 +203,40 @@ def batch_tab():
# generate captions in first pass
captions = []
for file in files:
if shared.state.interrupted:
break
image = Image.open(os.path.join(folder, file)).convert('RGB')
captions.append(ci.generate_caption(image))
shared.total_tqdm.update()
try:
if shared.state.interrupted:
break
image = Image.open(os.path.join(folder, file)).convert('RGB')
captions.append(ci.generate_caption(image))
shared.total_tqdm.update()
except OSError:
print(f" Could not read {file}; continuing")
continue
finally:
shared.total_tqdm.update()
# interrogate in second pass
writer = BatchWriter(folder, output_mode)
shared.total_tqdm.clear()
shared.total_tqdm.updateTotal(len(files))
for idx, file in enumerate(files):
if shared.state.interrupted:
break
image = Image.open(os.path.join(folder, file)).convert('RGB')
prompt = interrogate(image, mode, caption=captions[idx])
writer.add(file, prompt)
shared.total_tqdm.update()
try:
if shared.state.interrupted:
break
image = Image.open(os.path.join(folder, file)).convert('RGB')
try:
prompt = interrogate(image, mode, caption=captions[idx])
writer.add(file, prompt)
except IndexError as e:
print(f" {e}, continuing")
continue
except OSError as e:
print(f" {e}, continuing")
continue
finally:
shared.total_tqdm.update()
writer.close()
ci.config.quiet = False
unload()
@@ -239,7 +255,7 @@ def batch_tab():
clip_model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Prompt Mode', value='fast')
output_mode = gr.Dropdown(BATCH_OUTPUT_MODES, value=BATCH_OUTPUT_MODES[0], label='Output Mode')
with gr.Row():
with gr.Row():
button = gr.Button("Go!", variant='primary')
interrupt = gr.Button('Interrupt', visible=True)
interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])