diff --git a/scripts/clip_interrogator_ext.py b/scripts/clip_interrogator_ext.py index 9e1cb50..e062046 100644 --- a/scripts/clip_interrogator_ext.py +++ b/scripts/clip_interrogator_ext.py @@ -1,11 +1,13 @@ import gradio as gr import open_clip import clip_interrogator +import torch + from clip_interrogator import Config, Interrogator -from modules import devices, script_callbacks +from modules import devices, script_callbacks, shared, lowvram -__version__ = '0.0.1' +__version__ = '0.0.2' ci = None @@ -13,7 +15,22 @@ def load(clip_model_name): global ci if ci is None: print(f"Loading CLIP Interrogator {clip_interrogator.__version__}...") - ci = Interrogator(Config(device=devices.get_optimal_device(), clip_model_name=clip_model_name)) + + low_vram = shared.cmd_opts.lowvram or shared.cmd_opts.medvram + if not low_vram and torch.cuda.is_available(): + device = devices.get_optimal_device() + vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2) + if vram_total_mb < 12*1024*1024: + low_vram = True + print(f" detected < 12GB VRAM, using low VRAM mode") + + config = Config(device=devices.get_optimal_device(), clip_model_name=clip_model_name) + if low_vram: + config.blip_model_type = 'base' + config.blip_offload = True + config.chunk_size = 1024 + config.flavor_intermediate_count = 1024 + ci = Interrogator(config) if clip_model_name != ci.config.clip_model_name: ci.config.clip_model_name = clip_model_name ci.load_clip_model() @@ -24,8 +41,8 @@ def unload(): global ci if ci is not None: print("Offloading CLIP Interrogator...") - ci.blip_model = ci.blip_model.to("cpu") - ci.clip_model = ci.clip_model.to("cpu") + ci.blip_model = ci.blip_model.to(devices.cpu) + ci.clip_model = ci.clip_model.to(devices.cpu) devices.torch_gc() def get_models(): @@ -105,6 +122,11 @@ def about_tab(): gr.Markdown("If you have any issues please visit [CLIP Interrogator on Github](https://github.com/pharmapsychotic/clip-interrogator) and drop a star if you like it!") gr.Markdown(f"

CLIP Interrogator version: {clip_interrogator.__version__}
Extension version: {__version__}") + if torch.cuda.is_available(): + device = devices.get_optimal_device() + vram_total_mb = torch.cuda.get_device_properties(device).total_memory / (1024**2) + gr.Markdown(f"GPU VRAM: {vram_total_mb:.2f}MB") + def add_tab(): with gr.Blocks(analytics_enabled=False) as ui: with gr.Tab("Prompt"):