diff --git a/modules_forge/patch_clip.py b/modules_forge/patch_clip.py index de599089..3cdbdac7 100644 --- a/modules_forge/patch_clip.py +++ b/modules_forge/patch_clip.py @@ -19,7 +19,7 @@ import ldm_patched.modules.clip_vision import ldm_patched.modules.ops as ops from modules_forge.ops import use_patched_ops -from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection +from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None, @@ -106,49 +106,7 @@ def patched_SDClipModel_forward(self, tokens): return z.float(), pooled_output -def patched_ClipVisionModel__init__(self, json_config): - config = CLIPVisionConfig.from_json_file(json_config) - - self.load_device = ldm_patched.modules.model_management.text_encoder_device() - self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() - - if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - else: - self.dtype = torch.float32 - - with use_patched_ops(ops.manual_cast): - with modeling_utils.no_init_weights(): - self.model = CLIPVisionModelWithProjection(config) - - self.model.to(self.dtype) - self.patcher = ldm_patched.modules.model_patcher.ModelPatcher( - self.model, - load_device=self.load_device, - offload_device=self.offload_device - ) - - -def patched_ClipVisionModel_encode_image(self, image): - ldm_patched.modules.model_management.load_model_gpu(self.patcher) - pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device)) - outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) - - for k in outputs: - t = outputs[k] - if t is not None: - if k == 'hidden_states': - outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device()) - outputs["hidden_states"] = None - else: - outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device()) - - return outputs - - def patch_all_clip(): ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__ ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward - ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__ - ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image return