diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
index f842f52a..e35b8d60 100644
--- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
+++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
@@ -188,9 +188,8 @@ def zeroed_hidden_states(clip_vision, batch_size):
     image = torch.zeros([batch_size, 224, 224, 3])
     ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
     pixel_values = clip_preprocess(image.to(clip_vision.load_device)).float()
-    outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
-    # we only need the penultimate hidden states
-    outputs = outputs[1].to(ldm_patched.modules.model_management.intermediate_device())
+    outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = outputs.hidden_states[-2].to(ldm_patched.modules.model_management.intermediate_device())
     return outputs
 
 def min_(tensor_list):
diff --git a/ldm_patched/modules/clip_vision.py b/ldm_patched/modules/clip_vision.py
index 80355887..83c89173 100644
--- a/ldm_patched/modules/clip_vision.py
+++ b/ldm_patched/modules/clip_vision.py
@@ -1,16 +1,20 @@
-# Taken from https://github.com/comfyanonymous/ComfyUI
+# 1st edit by https://github.com/comfyanonymous/ComfyUI
+# 2nd edit by Forge
 
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
 import torch
-import json
 
 import ldm_patched.modules.ops
 import ldm_patched.modules.model_patcher
 import ldm_patched.modules.model_management
 import ldm_patched.modules.utils
 import ldm_patched.modules.clip_model
+import ldm_patched.modules.ops as ops
+
+from transformers import modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
+
 
 
 class Output:
     def __getitem__(self, key):
@@ -33,16 +37,26 @@ def clip_preprocess(image, size=224):
 
 class ClipVisionModel():
     def __init__(self, json_config):
-        with open(json_config) as f:
-            config = json.load(f)
+        config = CLIPVisionConfig.from_json_file(json_config)
 
         self.load_device = ldm_patched.modules.model_management.text_encoder_device()
-        offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
-        self.dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device)
-        self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.manual_cast)
-        self.model.eval()
+        self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
 
-        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.float32
+
+        with ops.use_patched_ops(ops.manual_cast):
+            with modeling_utils.no_init_weights():
+                self.model = CLIPVisionModelWithProjection(config)
+
+        self.model.to(self.dtype)
+        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+            self.model,
+            load_device=self.load_device,
+            offload_device=self.offload_device
+        )
 
     def load_sd(self, sd):
         return self.model.load_state_dict(sd, strict=False)
@@ -52,14 +66,15 @@ class ClipVisionModel():
 
     def encode_image(self, image):
         ldm_patched.modules.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device)).float()
-        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
+        outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
 
-        outputs = Output()
-        outputs["last_hidden_state"] = out[0].to(ldm_patched.modules.model_management.intermediate_device())
-        outputs["image_embeds"] = out[2].to(ldm_patched.modules.model_management.intermediate_device())
-        outputs["penultimate_hidden_states"] = out[1].to(ldm_patched.modules.model_management.intermediate_device())
-        return outputs
+        o = Output()
+        o["last_hidden_state"] = outputs.last_hidden_state.to(ldm_patched.modules.model_management.intermediate_device())
+        o["penultimate_hidden_states"] = outputs.hidden_states[-2].to(ldm_patched.modules.model_management.intermediate_device())
+        o["image_embeds"] = outputs.image_embeds.to(ldm_patched.modules.model_management.intermediate_device())
+
+        return o
 
 def convert_to_transformers(sd, prefix):
     sd_k = sd.keys()
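
Note on the API change (reviewer commentary, not part of the patch): both files move off the custom ldm_patched.modules.clip_model implementation and onto the standard Hugging Face transformers API. The sketch below illustrates the call pattern the new code relies on; the "openai/clip-vit-large-patch14" checkpoint is assumed purely for illustration, whereas the Forge code builds the model from a local JSON config and loads its own state dict via load_sd().

    import torch
    from transformers import CLIPVisionModelWithProjection

    # Build a CLIP vision tower with a projection head, as the new code does.
    model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    model.eval()

    # NCHW float tensor, the same layout clip_preprocess produces.
    pixel_values = torch.zeros([1, 3, 224, 224])

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, output_hidden_states=True)

    # hidden_states holds the embedding output plus one tensor per encoder layer,
    # so hidden_states[-2] is the penultimate layer -- the tensor the removed
    # intermediate_output=-2 argument used to select.
    image_embeds = outputs.image_embeds        # projected CLIP image embedding
    last_hidden = outputs.last_hidden_state    # final encoder layer
    penultimate = outputs.hidden_states[-2]    # what IP-Adapter consumes

The rewritten __init__ pairs transformers' modeling_utils.no_init_weights() (skipping random weight initialization, since real weights arrive from a state dict via load_sd()) with Forge's own ops.use_patched_ops(ops.manual_cast) so the instantiated modules use Forge's manual-cast ops; the dtype is then selected by should_use_fp16() instead of the previous text_encoder_dtype() call.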