From 92682b4dbf0fcaa3f7c6d539bf983be8cd154e93 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 12:41:18 -0800 Subject: [PATCH] Update interrogate.py --- modules/interrogate.py | 47 +++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 35a627ca..2d820bc7 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,7 +10,10 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils +from modules import devices, paths, shared, modelloader, errors +from ldm_patched.modules import model_management +from ldm_patched.modules.model_patcher import ModelPatcher + blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -53,7 +56,16 @@ class InterrogateModels: self.loaded_categories = None self.skip_categories = [] self.content_dir = content_dir - self.running_on_cpu = devices.device_interrogate == torch.device("cpu") + + self.load_device = model_management.text_encoder_device() + self.offload_device = model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + if model_management.should_use_fp16(device=self.load_device): + self.dtype = torch.float16 + + self.blip_patcher = None + self.clip_patcher = None def categories(self): if not os.path.exists(self.content_dir): @@ -119,35 +131,25 @@ class InterrogateModels: def load(self): if self.blip_model is None: self.blip_model = self.load_blip_model() - if not shared.cmd_opts.no_half and not self.running_on_cpu: - self.blip_model = self.blip_model.half() - - self.blip_model = self.blip_model.to(devices.device_interrogate) + self.blip_model = self.blip_model.to(device=self.offload_device, dtype=self.dtype) + self.blip_patcher = ModelPatcher(self.blip_model, load_device=self.load_device, offload_device=self.offload_device) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() - if not shared.cmd_opts.no_half and not self.running_on_cpu: - self.clip_model = self.clip_model.half() + self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype) + self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device) - self.clip_model = self.clip_model.to(devices.device_interrogate) - - self.dtype = torch_utils.get_param(self.clip_model).dtype + model_management.load_models_gpu([self.blip_patcher, self.clip_patcher]) + return def send_clip_to_ram(self): - if not shared.opts.interrogate_keep_models_in_memory: - if self.clip_model is not None: - self.clip_model = self.clip_model.to(devices.cpu) + pass def send_blip_to_ram(self): - if not shared.opts.interrogate_keep_models_in_memory: - if self.blip_model is not None: - self.blip_model = self.blip_model.to(devices.cpu) + pass def unload(self): - self.send_clip_to_ram() - self.send_blip_to_ram() - - devices.torch_gc() + pass def rank(self, image_features, text_array, top_count=1): import clip @@ -186,9 +188,6 @@ class InterrogateModels: res = "" shared.state.begin(job="interrogate") try: - lowvram.send_everything_to_cpu() - devices.torch_gc() - self.load() caption = self.generate_caption(pil_image)