Update interrogate.py

Author: lllyasviel
Date: 2024-01-25 12:41:18 -08:00
parent 95918b4f82
commit 92682b4dbf


@@ -10,7 +10,10 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
-from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
+from modules import devices, paths, shared, modelloader, errors
+
+from ldm_patched.modules import model_management
+from ldm_patched.modules.model_patcher import ModelPatcher
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -53,7 +56,16 @@ class InterrogateModels:
         self.loaded_categories = None
         self.skip_categories = []
         self.content_dir = content_dir
-        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
+
+        self.load_device = model_management.text_encoder_device()
+        self.offload_device = model_management.text_encoder_offload_device()
+
+        self.dtype = torch.float32
+        if model_management.should_use_fp16(device=self.load_device):
+            self.dtype = torch.float16
+
+        self.blip_patcher = None
+        self.clip_patcher = None
 
     def categories(self):
         if not os.path.exists(self.content_dir):
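
The constructor no longer tracks running_on_cpu or the --no-half flag. Instead it asks ldm_patched's model_management for the text-encoder load device, the offload device the models rest on when idle, and whether fp16 is safe on that load device. A minimal standalone sketch of that selection logic, assuming the ldm_patched package from this repository is importable (the print line is illustrative only):

    import torch

    from ldm_patched.modules import model_management

    # Where text-encoder-sized models should run, and where they rest when idle.
    load_device = model_management.text_encoder_device()
    offload_device = model_management.text_encoder_offload_device()

    # fp32 by default; fp16 only when the backend says the load device handles it.
    dtype = torch.float32
    if model_management.should_use_fp16(device=load_device):
        dtype = torch.float16

    print(f"run on {load_device}, idle on {offload_device}, dtype {dtype}")
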
@@ -119,35 +131,25 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
-            if not shared.cmd_opts.no_half and not self.running_on_cpu:
-                self.blip_model = self.blip_model.half()
-
-        self.blip_model = self.blip_model.to(devices.device_interrogate)
+            self.blip_model = self.blip_model.to(device=self.offload_device, dtype=self.dtype)
+            self.blip_patcher = ModelPatcher(self.blip_model, load_device=self.load_device, offload_device=self.offload_device)
 
         if self.clip_model is None:
             self.clip_model, self.clip_preprocess = self.load_clip_model()
-            if not shared.cmd_opts.no_half and not self.running_on_cpu:
-                self.clip_model = self.clip_model.half()
+            self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
+            self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
 
-        self.clip_model = self.clip_model.to(devices.device_interrogate)
-
-        self.dtype = torch_utils.get_param(self.clip_model).dtype
+        model_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
+        return
 
     def send_clip_to_ram(self):
-        if not shared.opts.interrogate_keep_models_in_memory:
-            if self.clip_model is not None:
-                self.clip_model = self.clip_model.to(devices.cpu)
+        pass
 
     def send_blip_to_ram(self):
-        if not shared.opts.interrogate_keep_models_in_memory:
-            if self.blip_model is not None:
-                self.blip_model = self.blip_model.to(devices.cpu)
+        pass
 
     def unload(self):
-        self.send_clip_to_ram()
-        self.send_blip_to_ram()
-
-        devices.torch_gc()
+        pass
 
     def rank(self, image_features, text_array, top_count=1):
         import clip
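
The reworked load() follows the ldm_patched pattern: park the raw module on the offload device in the chosen dtype, wrap it in a ModelPatcher that records both devices, and hand every patcher to model_management.load_models_gpu, which decides when the weights actually occupy the GPU. The manual send-to-RAM/unload helpers become no-ops because that bookkeeping now lives in model_management. A minimal sketch of the pattern on a stand-in module (the torch.nn.Linear is a placeholder for the BLIP/CLIP models, and the example assumes the ldm_patched package from this repository is importable):

    import torch

    from ldm_patched.modules import model_management
    from ldm_patched.modules.model_patcher import ModelPatcher

    load_device = model_management.text_encoder_device()
    offload_device = model_management.text_encoder_offload_device()
    dtype = torch.float16 if model_management.should_use_fp16(device=load_device) else torch.float32

    # Stand-in for self.blip_model / self.clip_model.
    model = torch.nn.Linear(16, 16).to(device=offload_device, dtype=dtype)

    # The patcher records where the model runs and where it is parked when evicted.
    patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)

    # model_management moves the wrapped weights onto the load device (and back
    # off it later) according to its own memory accounting.
    model_management.load_models_gpu([patcher])

    x = torch.zeros(1, 16, device=load_device, dtype=dtype)
    print(model(x).shape)
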
@@ -186,9 +188,6 @@ class InterrogateModels:
         res = ""
         shared.state.begin(job="interrogate")
         try:
-            lowvram.send_everything_to_cpu()
-            devices.torch_gc()
-
             self.load()
 
             caption = self.generate_caption(pil_image)
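
With model placement delegated to model_management, the interrogate entry point no longer needs the explicit lowvram/garbage-collection step before loading. A minimal sketch of the resulting call pattern; the instantiation and image path here are hypothetical, not taken from this commit:

    from PIL import Image

    from modules.interrogate import InterrogateModels

    interrogator = InterrogateModels("interrogate")   # hypothetical standalone setup
    pil_image = Image.open("example.png")             # hypothetical input image

    interrogator.load()   # wraps BLIP/CLIP in ModelPatchers, then load_models_gpu
    caption = interrogator.generate_caption(pil_image)
    print(caption)
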