mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-05 04:59:49 +00:00

Integrate CLIP
backend/clip.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+import torch
+
+
+class JointTokenizer:
+    def __init__(self, huggingface_components):
+        self.clip_l = huggingface_components.get('tokenizer', None)
+        self.clip_g = huggingface_components.get('tokenizer_2', None)
+
+
+class JointCLIP(torch.nn.Module):
+    def __init__(self, huggingface_components):
+        super().__init__()
+        self.clip_l = huggingface_components.get('text_encoder', None)
+        self.clip_g = huggingface_components.get('text_encoder_2', None)
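For context, `huggingface_components` is the dict of per-component objects that the loader below assembles from a Diffusers-style repo layout (`tokenizer`, `tokenizer_2`, `text_encoder`, `text_encoder_2`, ...). A minimal sketch of the wrappers' behavior; the component dict here is illustrative and not part of the commit:

# Illustrative only: the wrappers just pick named entries out of a
# Diffusers-style component dict, tolerating missing ones. An SD1.5-style
# checkpoint has no second tokenizer/text encoder, so clip_g stays None.
from backend.clip import JointCLIP, JointTokenizer

components = {'tokenizer': object()}   # stand-in for a CLIPTokenizer

tok = JointTokenizer(components)
assert tok.clip_l is not None
assert tok.clip_g is None              # no 'tokenizer_2' entry

clip = JointCLIP(components)           # no text encoders in this toy dict
assert clip.clip_l is None and clip.clip_g is None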
@@ -3,19 +3,59 @@ import importlib
 from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
-from backend.vae import load_vae
+from transformers import modeling_utils
+from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
+from backend.operations import using_forge_operations
+from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
+from backend.nn.clip import IntegratedCLIP, CLIPTextConfig


 dir_path = os.path.dirname(__file__)


-def load_component(component_name, lib_name, cls_name, repo_path, sd):
+def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)
-    if component_name in ['scheduler', 'tokenizer']:
-        cls = getattr(importlib.import_module(lib_name), cls_name)
-        return cls.from_pretrained(os.path.join(repo_path, component_name))
-    if cls_name in ['AutoencoderKL']:
-        return load_vae(sd, config_path)

     if component_name in ['feature_extractor', 'safety_checker']:
         return None

+    if lib_name in ['transformers', 'diffusers']:
+        if component_name in ['scheduler'] or component_name.startswith('tokenizer'):
+            cls = getattr(importlib.import_module(lib_name), cls_name)
+            return cls.from_pretrained(os.path.join(repo_path, component_name))
+        if cls_name in ['AutoencoderKL']:
+            sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
+            config = IntegratedAutoencoderKL.load_config(config_path)
+
+            with using_forge_operations():
+                model = IntegratedAutoencoderKL.from_config(config)
+
+            load_state_dict(model, sd)
+            return model
+        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+            if component_name == 'text_encoder':
+                sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
+            elif component_name == 'text_encoder_2':
+                sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
+            else:
+                raise ValueError(f"Wrong component_name: {component_name}")
+
+            if 'model.text_projection' in sd:
+                sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
+                sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
+                                                 "model.text_projection.weight": "text_projection",
+                                                 "model.logit_scale": "logit_scale"})
+
+            config = CLIPTextConfig.from_pretrained(config_path)
+
+            with modeling_utils.no_init_weights():
+                with using_forge_operations():
+                    model = IntegratedCLIP(config)
+
+            load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
+                                                      'transformer.text_model.embeddings.position_ids'])
+            return model

     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
     return None
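The prefix lists passed to `try_filter_state_dict` encode where each component lives inside the original single-file checkpoints: SD1.x keeps its text encoder under `cond_stage_model.`, SDXL uses `conditioner.embedders.0.` (CLIP-L) and `conditioner.embedders.1.` (CLIP-G), and the VAE sits under `first_stage_model.` or `vae.`. A toy illustration of that routing, with made-up keys and values:

# Toy illustration (made-up keys): how try_filter_state_dict routes a flat
# single-file checkpoint into per-component state dicts.
from backend.state_dict import try_filter_state_dict

sd = {
    'first_stage_model.decoder.conv_in.weight': 0,
    'conditioner.embedders.0.transformer.x': 1,   # SDXL CLIP-L
    'conditioner.embedders.1.model.y': 2,         # SDXL CLIP-G
}

vae_sd = try_filter_state_dict(sd, ['first_stage_model.', 'vae.'])
clip_l_sd = try_filter_state_dict(sd, ['cond_stage_model.', 'conditioner.embedders.0.'])
clip_g_sd = try_filter_state_dict(sd, ['conditioner.embedders.1.'])

print(vae_sd)     # {'decoder.conv_in.weight': 0}
print(clip_l_sd)  # {'transformer.x': 1}
print(clip_g_sd)  # {'model.y': 2}
assert sd == {}   # matched keys are popped from the source dict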
backend/nn/clip.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+import torch
+
+from transformers import CLIPTextModel, CLIPTextConfig
+
+
+class IntegratedCLIP(torch.nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.transformer = CLIPTextModel(config)
+        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
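`IntegratedCLIP` keeps `text_projection` and `logit_scale` as top-level parameters next to the HF `CLIPTextModel`, matching the layout the key remapping above produces (`model.text_projection` → `text_projection`, `model.logit_scale` → `logit_scale`). `text_projection` starts as an identity matrix sized to the embedding width, so a checkpoint that lacks the key leaves pooled outputs unchanged; 4.6055 is approximately ln(100), the ceiling CLIP clamps its temperature to. A hedged sketch of how the projection would be applied to the pooled output (the forward pass itself is not part of this commit):

# Hypothetical sketch, not from the commit: applying text_projection to the
# pooled CLIP output. With the identity initialization this is a no-op until
# a checkpoint overwrites text_projection (e.g. SDXL's CLIP-G weights).
import torch
from transformers import CLIPTextConfig
from backend.nn.clip import IntegratedCLIP

config = CLIPTextConfig()  # library defaults: 512-wide text model
model = IntegratedCLIP(config)

tokens = torch.tensor([[49406, 320, 1125, 49407]])  # BOS, two ids, EOS
pooled = model.transformer(input_ids=tokens).pooler_output  # [1, hidden]
projected = pooled @ model.text_projection                  # identity at init
assert torch.allclose(pooled, projected)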

backend/state_dict.py
@@ -1,12 +1,95 @@
 import torch


-def filter_state_dict_with_prefix(sd, prefix):
+def load_state_dict(model, sd, ignore_errors=[]):
+    missing, unexpected = model.load_state_dict(sd, strict=False)
+    missing = [x for x in missing if x not in ignore_errors]
+    unexpected = [x for x in unexpected if x not in ignore_errors]
+    if len(missing) > 0:
+        print(f'{type(model).__name__} Missing: {missing}')
+    if len(unexpected) > 0:
+        print(f'{type(model).__name__} Unexpected: {unexpected}')
+    return
+
+
+def state_dict_has(sd, prefix):
+    return any(x.startswith(prefix) for x in sd.keys())
+
+
+def filter_state_dict_with_prefix(sd, prefix, new_prefix=''):
     new_sd = {}

     for k, v in list(sd.items()):
         if k.startswith(prefix):
-            new_sd[k[len(prefix):]] = v
+            new_sd[new_prefix + k[len(prefix):]] = v
             del sd[k]

     return new_sd
+
+
+def try_filter_state_dict(sd, prefix_list, new_prefix=''):
+    for prefix in prefix_list:
+        if state_dict_has(sd, prefix):
+            return filter_state_dict_with_prefix(sd, prefix, new_prefix)
+    return {}
+
+
+def transformers_convert(sd, prefix_from, prefix_to, number):
+    keys_to_replace = {
+        "{}positional_embedding": "{}embeddings.position_embedding.weight",
+        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
+        "{}ln_final.weight": "{}final_layer_norm.weight",
+        "{}ln_final.bias": "{}final_layer_norm.bias",
+    }
+
+    for k in keys_to_replace:
+        x = k.format(prefix_from)
+        if x in sd:
+            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
+
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(number):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
+    return sd
+
+
+def state_dict_key_replace(state_dict, keys_to_replace):
+    for x in keys_to_replace:
+        if x in state_dict:
+            state_dict[keys_to_replace[x]] = state_dict.pop(x)
+    return state_dict
+
+
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
+        for x in replace:
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out
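The most delicate part of `transformers_convert` is the attention remap: OpenCLIP-style checkpoints store Q, K, V as one fused `attn.in_proj_{weight,bias}` tensor, while HF's `CLIPTextModel` expects separate `q_proj`/`k_proj`/`v_proj`, so the fused tensor is sliced into three equal chunks along dim 0. A small self-check with made-up shapes:

# Self-check with made-up shapes: the fused in_proj weight [3*d, d] is split
# into q/k/v chunks of d rows each, and keys are rewritten from the OpenCLIP
# layout to the HF transformers layout.
import torch
from backend.state_dict import transformers_convert

d = 8
sd = {
    'model.transformer.resblocks.0.attn.in_proj_weight': torch.arange(3 * d * d, dtype=torch.float32).reshape(3 * d, d),
    'model.ln_final.weight': torch.ones(d),
}

sd = transformers_convert(sd, 'model.', 'transformer.text_model.', number=1)

q = sd['transformer.text_model.encoder.layers.0.self_attn.q_proj.weight']
v = sd['transformer.text_model.encoder.layers.0.self_attn.v_proj.weight']
assert q.shape == (d, d) and v.shape == (d, d)
assert 'transformer.text_model.final_layer_norm.weight' in sd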

backend/vae.py (deleted file, 14 lines)
@@ -1,14 +0,0 @@
-from backend.state_dict import filter_state_dict_with_prefix
-from backend.operations import using_forge_operations
-from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
-
-
-def load_vae(state_dict, config_path):
-    config = IntegratedAutoencoderKL.load_config(config_path)
-
-    with using_forge_operations():
-        model = IntegratedAutoencoderKL.from_config(config)
-
-    vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.")
-    model.load_state_dict(vae_state_dict, strict=True)
-    return model

@@ -352,7 +352,7 @@ class LoadedModel:
             elif hasattr(m, "weight"):
                 m.to(self.device)
                 mem_counter += module_size(m)
-                print(f"[Memory Management] {flag} Loader Disabled for ", m)
+                print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__)

         print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
         print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))

@@ -97,22 +97,22 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     return model, clip


+from backend.clip import JointCLIP, JointTokenizer
+
+
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False):
+    def __init__(self, huggingface_components, no_init=False):
         if no_init:
             return

-        params = target.params.copy()
-        clip = target.clip
-        tokenizer = target.tokenizer

         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        text_encoder_dtype = model_management.text_encoder_dtype(load_device)

-        self.cond_stage_model = clip(**(params))
-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.cond_stage_model = JointCLIP(huggingface_components)
+        self.tokenizer = JointTokenizer(huggingface_components)
+        self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)

         self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
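The reworked `CLIP` no longer instantiates a model class from a `clip_target` plus kwargs; it wraps the already-built HF components and only decides placement: weights sit on the offload device in the text-encoder dtype, and the `ModelPatcher` is what later shuttles them to the load device around prompt encoding. A rough sketch of that flow; it assumes a populated `huggingface_components` dict and the surrounding `model_management` module, and is illustrative rather than runnable standalone:

# Rough sketch (assumes the surrounding codebase; illustrative only).
clip = CLIP(huggingface_components)

# After __init__, parameters live on the offload device in the chosen dtype:
p = next(clip.cond_stage_model.parameters())
print(p.device, p.dtype)   # e.g. cpu torch.float16

# clip.patcher (a ModelPatcher) is what moves cond_stage_model between
# offload_device and load_device when conditioning is actually computed.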

@@ -112,13 +112,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
     vae = VAE(model=vae)

     if output_clip:
-        w = WeightsLoader()
-        clip_target = model_config.clip_target()
-        if clip_target is not None:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory)
-            w.cond_stage_model = clip.cond_stage_model
-            sd = model_config.process_clip_state_dict(sd)
-            load_model_weights(w, sd)
+        clip = CLIP(huggingface_components)

     left_over = sd.keys()
     if len(left_over) > 0:

@@ -177,7 +171,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
             embedder = conditioner.embedders[i]
             typename = type(embedder).__name__
             if typename == 'FrozenCLIPEmbedder':  # SDXL Clip L
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+                embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
                 embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
                 model_embeddings = embedder.transformer.text_model.embeddings
                 model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(

@@ -186,7 +180,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
                 conditioner.embedders[i] = embedder
                 text_cond_models.append(embedder)
             elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_g.tokenizer
+                embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
                 embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
                 embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
                 model_embeddings = embedder.transformer.text_model.embeddings
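These rewiring hunks, including the SD1.5/SD2.1 one below, all make the same mechanical change: `JointTokenizer` stores the raw HF tokenizers directly, so each site drops one `.tokenizer` hop. Illustrative before/after:

# Before: clip.tokenizer.clip_l was a wrapper holding the HF tokenizer in
# its own .tokenizer attribute. After: JointTokenizer.clip_l is the HF
# CLIPTokenizer itself, so the embedder is wired one level higher.
embedder.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer  # old
embedder.tokenizer = forge_objects.clip.tokenizer.clip_l            # new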

@@ -201,14 +195,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
         else:
             sd_model.cond_stage_model = conditioner
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':  # SD15 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
         sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
         model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
             model_embeddings.token_embedding, sd_hijack.model_hijack)
         sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h
         sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
         model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(