From 2c60a2b97379994d95d4c48bd791fb337ef074a1 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 12:49:01 -0800 Subject: [PATCH] Update interrogate.py --- modules/interrogate.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 2d820bc7..43365a83 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -118,13 +118,8 @@ class InterrogateModels: def load_clip_model(self): import clip - if self.running_on_cpu: - model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path) - else: - model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path) - + model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path) model.eval() - model = model.to(devices.device_interrogate) return model, preprocess @@ -160,11 +155,11 @@ class InterrogateModels: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] top_count = min(top_count, len(text_array)) - text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate) + text_tokens = clip.tokenize(list(text_array), truncate=True).to(self.load_device) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate) + similarity = torch.zeros((1, len(text_array))).to(self.load_device) for i in range(image_features.shape[0]): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= image_features.shape[0] @@ -177,7 +172,7 @@ class InterrogateModels: transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) + ])(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device) with torch.no_grad(): caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) @@ -196,7 +191,7 @@ class InterrogateModels: res = caption - clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device) with torch.no_grad(), devices.autocast(): image_features = self.clip_model.encode_image(clip_image).type(self.dtype)