Update patch_clip.py

This commit is contained in:
lllyasviel
2024-01-25 23:36:12 -08:00
parent 7147972e80
commit 5ac1da9fe8

View File

@@ -19,7 +19,7 @@ import ldm_patched.modules.clip_vision
import ldm_patched.modules.ops as ops
from modules_forge.ops import use_patched_ops
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
@@ -106,49 +106,7 @@ def patched_SDClipModel_forward(self, tokens):
return z.float(), pooled_output
def patched_ClipVisionModel__init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
self.load_device = ldm_patched.modules.model_management.text_encoder_device()
self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
else:
self.dtype = torch.float32
with use_patched_ops(ops.manual_cast):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
self.model,
load_device=self.load_device,
offload_device=self.offload_device
)
def patched_ClipVisionModel_encode_image(self, image):
ldm_patched.modules.model_management.load_model_gpu(self.patcher)
pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
t = outputs[k]
if t is not None:
if k == 'hidden_states':
outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device())
outputs["hidden_states"] = None
else:
outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device())
return outputs
def patch_all_clip():
ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__
ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image
return