mirror of
https://github.com/pharmapsychotic/clip-interrogator-ext.git
synced 2026-04-28 10:11:27 +00:00
Merge pull request #49 from genevera/handle_unreadable_files
adds try/catch/finally blocks to handle bad images
This commit is contained in:
@@ -11,7 +11,7 @@ from clip_interrogator import Config, Interrogator
|
||||
|
||||
from modules import devices, lowvram, script_callbacks, shared
|
||||
|
||||
__version__ = '0.1.4'
|
||||
__version__ = '0.1.5'
|
||||
|
||||
ci = None
|
||||
low_vram = False
|
||||
@@ -55,7 +55,7 @@ def load(clip_model_name):
|
||||
print(f"Loading CLIP Interrogator {clip_interrogator.__version__}...")
|
||||
|
||||
config = Config(
|
||||
device=devices.get_optimal_device(),
|
||||
device=devices.get_optimal_device(),
|
||||
cache_path = 'models/clip-interrogator',
|
||||
clip_model_name=clip_model_name,
|
||||
blip_model=shared.interrogator.load_blip_model().float()
|
||||
@@ -95,7 +95,7 @@ def image_analysis(image, clip_model_name):
|
||||
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
|
||||
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
|
||||
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
|
||||
|
||||
|
||||
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
|
||||
|
||||
def interrogate(image, mode, caption=None):
|
||||
@@ -117,7 +117,7 @@ def image_to_prompt(image, mode, clip_model_name):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
|
||||
try:
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
@@ -169,28 +169,32 @@ def analyze_tab():
|
||||
model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
with gr.Row():
|
||||
medium = gr.Label(label="Medium", num_top_classes=5)
|
||||
artist = gr.Label(label="Artist", num_top_classes=5)
|
||||
artist = gr.Label(label="Artist", num_top_classes=5)
|
||||
movement = gr.Label(label="Movement", num_top_classes=5)
|
||||
trending = gr.Label(label="Trending", num_top_classes=5)
|
||||
flavor = gr.Label(label="Flavor", num_top_classes=5)
|
||||
button = gr.Button("Analyze", variant='primary')
|
||||
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
|
||||
|
||||
|
||||
def batch_tab():
|
||||
def batch_process(folder, clip_model, mode, output_mode):
|
||||
if not os.path.exists(folder):
|
||||
return f"Folder {folder} does not exist"
|
||||
print(f"Folder {folder} does not exist")
|
||||
return
|
||||
if not os.path.isdir(folder):
|
||||
return "{folder} is not a directory"
|
||||
print("{folder} is not a directory")
|
||||
return
|
||||
|
||||
files = [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
||||
if not files:
|
||||
return "Folder has no images"
|
||||
print("Folder has no images")
|
||||
return
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'batch interrogate'
|
||||
|
||||
try:
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
@@ -202,25 +206,35 @@ def batch_tab():
|
||||
|
||||
# generate captions in first pass
|
||||
captions = []
|
||||
|
||||
for file in files:
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(os.path.join(folder, file)).convert('RGB')
|
||||
captions.append(ci.generate_caption(image))
|
||||
shared.total_tqdm.update()
|
||||
try:
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(os.path.join(folder, file)).convert('RGB')
|
||||
caption = ci.generate_caption(image)
|
||||
except OSError as e:
|
||||
print(f"{e}; continuing")
|
||||
caption = ""
|
||||
finally:
|
||||
captions.append(caption)
|
||||
shared.total_tqdm.update()
|
||||
|
||||
# interrogate in second pass
|
||||
writer = BatchWriter(folder, output_mode)
|
||||
shared.total_tqdm.clear()
|
||||
shared.total_tqdm.updateTotal(len(files))
|
||||
for idx, file in enumerate(files):
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(os.path.join(folder, file)).convert('RGB')
|
||||
prompt = interrogate(image, mode, caption=captions[idx])
|
||||
writer.add(file, prompt)
|
||||
shared.total_tqdm.update()
|
||||
|
||||
try:
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(os.path.join(folder, file)).convert('RGB')
|
||||
prompt = interrogate(image, mode, caption=captions[idx])
|
||||
writer.add(file, prompt)
|
||||
except OSError as e:
|
||||
print(f" {e}, continuing")
|
||||
finally:
|
||||
shared.total_tqdm.update()
|
||||
writer.close()
|
||||
ci.config.quiet = False
|
||||
unload()
|
||||
@@ -239,7 +253,7 @@ def batch_tab():
|
||||
clip_model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
mode = gr.Radio(['caption', 'best', 'fast', 'classic', 'negative'], label='Prompt Mode', value='fast')
|
||||
output_mode = gr.Dropdown(BATCH_OUTPUT_MODES, value=BATCH_OUTPUT_MODES[0], label='Output Mode')
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
button = gr.Button("Go!", variant='primary')
|
||||
interrupt = gr.Button('Interrupt', visible=True)
|
||||
interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
|
||||
|
||||
Reference in New Issue
Block a user