https://github.com/lllyasviel/stable-diffusion-webui-forge.git
try solve saturation problems for instant id in #155
@@ -188,9 +188,8 @@ def zeroed_hidden_states(clip_vision, batch_size):
     image = torch.zeros([batch_size, 224, 224, 3])
     ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
     pixel_values = clip_preprocess(image.to(clip_vision.load_device)).float()
-    outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2)
-    # we only need the penultimate hidden states
-    outputs = outputs[1].to(ldm_patched.modules.model_management.intermediate_device())
+    outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = outputs.hidden_states[-2].to(ldm_patched.modules.model_management.intermediate_device())
     return outputs

 def min_(tensor_list):
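This first hunk updates the zeroed_hidden_states helper on the IP-Adapter side; the hunks below make the matching change in the clip_vision module. The substance of the fix: the old in-tree CLIP vision model was called with intermediate_output=-2 and returned a tuple whose element 1 held that intermediate hidden state, while the transformers-backed call requests every hidden state and takes hidden_states[-2]. A minimal sketch of that indexing convention (not part of the commit; the checkpoint name is an illustrative assumption):

# Sketch only: demonstrates why hidden_states[-2] is the penultimate layer
# when a transformers CLIP vision model runs with output_hidden_states=True.
# The checkpoint name is an illustrative assumption, not taken from the commit.
import torch
from transformers import CLIPVisionModelWithProjection

model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
model.eval()

pixel_values = torch.zeros([1, 3, 224, 224])  # dummy preprocessed batch
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_hidden_states=True)

# hidden_states holds the embedding output plus one entry per encoder layer:
assert len(outputs.hidden_states) == model.config.num_hidden_layers + 1
print(outputs.hidden_states[-2].shape)  # torch.Size([1, 257, 1024]) for ViT-L/14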
@@ -1,16 +1,20 @@
-# Taken from https://github.com/comfyanonymous/ComfyUI
+# 1st edit by https://github.com/comfyanonymous/ComfyUI
+# 2nd edit by Forge
+

 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
 import torch
 import json

 import ldm_patched.modules.ops
 import ldm_patched.modules.model_patcher
 import ldm_patched.modules.model_management
 import ldm_patched.modules.utils
-import ldm_patched.modules.clip_model
+import ldm_patched.modules.ops as ops
+
+from transformers import modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection


 class Output:
     def __getitem__(self, key):
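The import changes show the direction of the rewrite: the in-tree ldm_patched.modules.clip_model implementation is dropped in favor of Hugging Face transformers. no_init_weights() matters because the real weights arrive from a checkpoint right after construction, so random initialization would be wasted work. A hedged sketch of that pattern, substituting a default CLIPVisionConfig for the real JSON config file:

# Sketch only, under stated assumptions: allocate the transformers module with
# weight initialization skipped, cast it, and fill it from a state dict later
# (which is what ClipVisionModel.load_sd() does below via load_state_dict).
import torch
from transformers import modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection

config = CLIPVisionConfig()  # library defaults stand in for the real JSON config

with modeling_utils.no_init_weights():  # parameters are allocated, not initialized
    model = CLIPVisionModelWithProjection(config)

model = model.to(torch.float16)  # cast after construction, as the commit does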
@@ -33,16 +37,26 @@ def clip_preprocess(image, size=224):

 class ClipVisionModel():
     def __init__(self, json_config):
-        with open(json_config) as f:
-            config = json.load(f)
+        config = CLIPVisionConfig.from_json_file(json_config)

         self.load_device = ldm_patched.modules.model_management.text_encoder_device()
-        offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
-        self.dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device)
-        self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.manual_cast)
-        self.model.eval()
+        self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()

-        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.float32
+
+        with ops.use_patched_ops(ops.manual_cast):
+            with modeling_utils.no_init_weights():
+                self.model = CLIPVisionModelWithProjection(config)
+
+        self.model.to(self.dtype)
+        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+            self.model,
+            load_device=self.load_device,
+            offload_device=self.offload_device
+        )

     def load_sd(self, sd):
         return self.model.load_state_dict(sd, strict=False)
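With the constructor rewritten this way, loading looks roughly like the sketch below. Both file names are hypothetical placeholders; load_sd() wraps load_state_dict(strict=False), which returns the lists of missing and unexpected keys instead of raising on a mismatch:

# Illustrative usage of the rewritten class; both file names are hypothetical.
from ldm_patched.modules.clip_vision import ClipVisionModel
from ldm_patched.modules.utils import load_torch_file

model = ClipVisionModel("clip_vision_config_vitl.json")  # hypothetical config path
sd = load_torch_file("clip_vision_vitl.safetensors")     # hypothetical weights path
missing, unexpected = model.load_sd(sd)                  # key mismatches are tolerated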
@@ -52,14 +66,15 @@ class ClipVisionModel():

     def encode_image(self, image):
         ldm_patched.modules.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device)).float()
-        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
+        outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)

-        outputs = Output()
-        outputs["last_hidden_state"] = out[0].to(ldm_patched.modules.model_management.intermediate_device())
-        outputs["image_embeds"] = out[2].to(ldm_patched.modules.model_management.intermediate_device())
-        outputs["penultimate_hidden_states"] = out[1].to(ldm_patched.modules.model_management.intermediate_device())
-        return outputs
+        o = Output()
+        o["last_hidden_state"] = outputs.last_hidden_state.to(ldm_patched.modules.model_management.intermediate_device())
+        o["penultimate_hidden_states"] = outputs.hidden_states[-2].to(ldm_patched.modules.model_management.intermediate_device())
+        o["image_embeds"] = outputs.image_embeds.to(ldm_patched.modules.model_management.intermediate_device())
+
+        return o

 def convert_to_transformers(sd, prefix):
     sd_k = sd.keys()
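Downstream, the entry that matters for the #155 saturation fix is "penultimate_hidden_states": the removed comment in the first hunk ("we only need the penultimate hidden states") says as much, and the rewrite makes sure that entry really is the next-to-last encoder layer rather than the final, post-layer-norm output. A hedged usage sketch, reusing the model from the previous sketch; the shape comments assume a ViT-L/14 vision tower:

# Sketch only: how a caller would consume encode_image(); `model` comes from the
# loading sketch above, and the shape comments assume a ViT-L/14 vision tower.
import torch

image = torch.rand([1, 512, 512, 3])     # [B, H, W, C] image in 0..1, Forge/ComfyUI layout
out = model.encode_image(image)          # clip_preprocess scales/crops to 224x224 internally

cond = out["penultimate_hidden_states"]  # [1, 257, 1024], the tensor InstantID conditions on
embeds = out["image_embeds"]             # [1, 768], projected CLIP image embedding
last = out["last_hidden_state"]          # [1, 257, 1024], final-layer tokens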