diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py
index 9e1cb50..e062046 100644
--- a/scripts/clip_interrogator_ext.py
+++ b/scripts/clip_interrogator_ext.py
@@ -1,11 +1,13 @@
import gradio as gr
import open_clip
import clip_interrogator
+import torch
+
from clip_interrogator import Config, Interrogator
-from modules import devices, script_callbacks
+from modules import devices, script_callbacks, shared, lowvram
-__version__ = '0.0.1'
+__version__ = '0.0.2'
ci = None
@@ -13,7 +15,22 @@ def load(clip_model_name):
global ci
if ci is None:
print(f"Loading CLIP Interrogator {clip_interrogator.__version__}...")
- ci = Interrogator(Config(device=devices.get_optimal_device(), clip_model_name=clip_model_name))
+
+ low_vram = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
+ if not low_vram and torch.cuda.is_available():
+ device = devices.get_optimal_device()
+ vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2)
+ if vram_total_mb < 12*1024*1024:
+ low_vram = True
+ print(f" detected < 12GB VRAM, using low VRAM mode")
+
+ config = Config(device=devices.get_optimal_device(), clip_model_name=clip_model_name)
+ if low_vram:
+ config.blip_model_type = 'base'
+ config.blip_offload = True
+ config.chunk_size = 1024
+ config.flavor_intermediate_count = 1024
+ ci = Interrogator(config)
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
@@ -24,8 +41,8 @@ def unload():
global ci
if ci is not None:
print("Offloading CLIP Interrogator...")
- ci.blip_model = ci.blip_model.to("cpu")
- ci.clip_model = ci.clip_model.to("cpu")
+ ci.blip_model = ci.blip_model.to(devices.cpu)
+ ci.clip_model = ci.clip_model.to(devices.cpu)
devices.torch_gc()
def get_models():
@@ -105,6 +122,11 @@ def about_tab():
gr.Markdown("If you have any issues please visit [CLIP Interrogator on Github](https://github.com/pharmapsychotic/clip-interrogator) and drop a star if you like it!")
gr.Markdown(f"
CLIP Interrogator version: {clip_interrogator.__version__}
Extension version: {__version__}")
+ if torch.cuda.is_available():
+ device = devices.get_optimal_device()
+ vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2)
+ gr.Markdown(f"GPU VRAM: {vram_total_mb:.2f}MB")
+
def add_tab():
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tab("Prompt"):