Integrate CLIP

layerdiffusion
2024-08-01 12:24:35 -07:00
parent af0b04cc16
commit 4d1be42975
8 changed files with 172 additions and 44 deletions

backend/clip.py (new file, 14 lines)

@@ -0,0 +1,14 @@
+import torch
+
+
+class JointTokenizer:
+    def __init__(self, huggingface_components):
+        self.clip_l = huggingface_components.get('tokenizer', None)
+        self.clip_g = huggingface_components.get('tokenizer_2', None)
+
+
+class JointCLIP(torch.nn.Module):
+    def __init__(self, huggingface_components):
+        super().__init__()
+        self.clip_l = huggingface_components.get('text_encoder', None)
+        self.clip_g = huggingface_components.get('text_encoder_2', None)

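Note (annotation, not part of the commit): JointTokenizer and JointCLIP simply pick the CLIP-L and CLIP-G pieces out of a diffusers-style components dict; for SD 1.5 checkpoints the *_2 entries are absent and clip_g stays None. A minimal usage sketch, assuming a hypothetical SDXL pipeline loaded via diffusers:

    # Hypothetical illustration: component names follow the diffusers SDXL convention.
    from diffusers import StableDiffusionXLPipeline
    from backend.clip import JointCLIP, JointTokenizer

    pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
    components = pipe.components  # dict with 'tokenizer', 'tokenizer_2', 'text_encoder', 'text_encoder_2', ...

    tokenizer = JointTokenizer(components)  # .clip_l / .clip_g tokenizers
    clip = JointCLIP(components)            # .clip_l / .clip_g text encoders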
backend/loader.py

@@ -3,19 +3,59 @@ import importlib
 from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
-from backend.vae import load_vae
+from transformers import modeling_utils
+from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
+from backend.operations import using_forge_operations
+from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
+from backend.nn.clip import IntegratedCLIP, CLIPTextConfig


 dir_path = os.path.dirname(__file__)


-def load_component(component_name, lib_name, cls_name, repo_path, sd):
+def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)

-    if component_name in ['scheduler', 'tokenizer']:
-        cls = getattr(importlib.import_module(lib_name), cls_name)
-        return cls.from_pretrained(os.path.join(repo_path, component_name))
-    if cls_name in ['AutoencoderKL']:
-        return load_vae(sd, config_path)
     if component_name in ['feature_extractor', 'safety_checker']:
         return None
+    if lib_name in ['transformers', 'diffusers']:
+        if component_name in ['scheduler'] or component_name.startswith('tokenizer'):
+            cls = getattr(importlib.import_module(lib_name), cls_name)
+            return cls.from_pretrained(os.path.join(repo_path, component_name))
+        if cls_name in ['AutoencoderKL']:
+            sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
+            config = IntegratedAutoencoderKL.load_config(config_path)
+            with using_forge_operations():
+                model = IntegratedAutoencoderKL.from_config(config)
+            load_state_dict(model, sd)
+            return model
+        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+            if component_name == 'text_encoder':
+                sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
+            elif component_name == 'text_encoder_2':
+                sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
+            else:
+                raise ValueError(f"Wrong component_name: {component_name}")
+            if 'model.text_projection' in sd:
+                sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
+                sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
+                                                 "model.text_projection.weight": "text_projection",
+                                                 "model.logit_scale": "logit_scale"})
+            config = CLIPTextConfig.from_pretrained(config_path)
+            with modeling_utils.no_init_weights():
+                with using_forge_operations():
+                    model = IntegratedCLIP(config)
+            load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
+                                                      'transformer.text_model.embeddings.position_ids'])
+            return model
+
+    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
+    return None

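Note (annotation, not part of the commit): load_component now takes the raw checkpoint state_dict and builds the VAE and text encoders itself, filtering keys by the ldm/sgm prefixes (first_stage_model., cond_stage_model., conditioner.embedders.N.) instead of delegating to backend.vae. A hedged sketch of a driver loop, assuming a diffusers-style model_index.json; the surrounding loading code is not shown in this commit:

    # Hypothetical driver: model_index.json maps component names to [library, class].
    import json, os

    with open(os.path.join(repo_path, 'model_index.json')) as f:
        model_index = json.load(f)

    huggingface_components = {}
    for name, value in model_index.items():
        if name.startswith('_') or not isinstance(value, list):
            continue  # skip metadata entries like '_class_name'
        lib_name, cls_name = value
        huggingface_components[name] = load_component(name, lib_name, cls_name, repo_path, state_dict)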
backend/nn/clip.py (new file, 11 lines)

@@ -0,0 +1,11 @@
+import torch
+from transformers import CLIPTextModel, CLIPTextConfig
+
+
+class IntegratedCLIP(torch.nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.transformer = CLIPTextModel(config)
+        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))

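Note (annotation, not part of the commit): text_projection is initialized as an identity matrix of the encoder hidden size, so it is a no-op until a real projection from the checkpoint overwrites it, and 4.6055 is approximately ln(100), i.e. CLIP's learned logit temperature exp(logit_scale) ≈ 100. A quick sanity check, assuming a tiny config:

    # Hypothetical sanity check with a small CLIPTextConfig.
    import math, torch
    from transformers import CLIPTextConfig
    from backend.nn.clip import IntegratedCLIP

    cfg = CLIPTextConfig(hidden_size=64, intermediate_size=128,
                         num_hidden_layers=2, num_attention_heads=2)
    m = IntegratedCLIP(cfg)
    assert torch.equal(m.text_projection, torch.eye(64))  # identity until loaded over
    print(math.isclose(m.logit_scale.exp().item(), 100.0, rel_tol=0.01))  # True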
backend/state_dict.py

@@ -1,12 +1,95 @@
 import torch


-def filter_state_dict_with_prefix(sd, prefix):
+def load_state_dict(model, sd, ignore_errors=[]):
+    missing, unexpected = model.load_state_dict(sd, strict=False)
+    missing = [x for x in missing if x not in ignore_errors]
+    unexpected = [x for x in unexpected if x not in ignore_errors]
+    if len(missing) > 0:
+        print(f'{type(model).__name__} Missing: {missing}')
+    if len(unexpected) > 0:
+        print(f'{type(model).__name__} Unexpected: {unexpected}')
+    return
+
+
+def state_dict_has(sd, prefix):
+    return any(x.startswith(prefix) for x in sd.keys())
+
+
+def filter_state_dict_with_prefix(sd, prefix, new_prefix=''):
     new_sd = {}
     for k, v in list(sd.items()):
         if k.startswith(prefix):
-            new_sd[k[len(prefix):]] = v
+            new_sd[new_prefix + k[len(prefix):]] = v
             del sd[k]
     return new_sd
+
+
+def try_filter_state_dict(sd, prefix_list, new_prefix=''):
+    for prefix in prefix_list:
+        if state_dict_has(sd, prefix):
+            return filter_state_dict_with_prefix(sd, prefix, new_prefix)
+    return {}
+
+
+def transformers_convert(sd, prefix_from, prefix_to, number):
+    keys_to_replace = {
+        "{}positional_embedding": "{}embeddings.position_embedding.weight",
+        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
+        "{}ln_final.weight": "{}final_layer_norm.weight",
+        "{}ln_final.bias": "{}final_layer_norm.bias",
+    }
+
+    for k in keys_to_replace:
+        x = k.format(prefix_from)
+        if x in sd:
+            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
+
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(number):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
+
+    return sd
+
+
+def state_dict_key_replace(state_dict, keys_to_replace):
+    for x in keys_to_replace:
+        if x in state_dict:
+            state_dict[keys_to_replace[x]] = state_dict.pop(x)
+    return state_dict
+
+
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
+        for x in replace:
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out

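Note (annotation, not part of the commit): transformers_convert is the usual OpenCLIP-to-transformers key remap (resblocks → encoder.layers, the fused attn.in_proj tensor split into q/k/v projections), and try_filter_state_dict pops a matching sub-tree out of the checkpoint in place. A toy run with made-up keys and shapes:

    # Toy demonstration; keys and tensor sizes are invented.
    import torch
    from backend.state_dict import try_filter_state_dict, transformers_convert

    sd = {
        'conditioner.embedders.1.model.ln_final.weight': torch.zeros(8),
        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight': torch.zeros(8),
        'conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight': torch.zeros(24, 8),
    }

    sub = try_filter_state_dict(sd, ['conditioner.embedders.1.'])  # strips the first matching prefix
    sub = transformers_convert(sub, 'model.', 'transformer.text_model.', 1)
    print(sorted(sub.keys()))
    # ['transformer.text_model.encoder.layers.0.layer_norm1.weight',
    #  'transformer.text_model.encoder.layers.0.self_attn.k_proj.weight',
    #  'transformer.text_model.encoder.layers.0.self_attn.q_proj.weight',
    #  'transformer.text_model.encoder.layers.0.self_attn.v_proj.weight',
    #  'transformer.text_model.final_layer_norm.weight']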
backend/vae.py (deleted file)

@@ -1,14 +0,0 @@
from backend.state_dict import filter_state_dict_with_prefix
from backend.operations import using_forge_operations
from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
def load_vae(state_dict, config_path):
config = IntegratedAutoencoderKL.load_config(config_path)
with using_forge_operations():
model = IntegratedAutoencoderKL.from_config(config)
vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.")
model.load_state_dict(vae_state_dict, strict=True)
return model


@@ -352,7 +352,7 @@ class LoadedModel:
             elif hasattr(m, "weight"):
                 m.to(self.device)
                 mem_counter += module_size(m)
-                print(f"[Memory Management] {flag} Loader Disabled for ", m)
+                print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__)

         print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
         print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))


@@ -97,22 +97,22 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     return model, clip


+from backend.clip import JointCLIP, JointTokenizer
+
+
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False):
+    def __init__(self, huggingface_components, no_init=False):
         if no_init:
             return

-        params = target.params.copy()
-        clip = target.clip
-        tokenizer = target.tokenizer
-
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        text_encoder_dtype = model_management.text_encoder_dtype(load_device)

-        self.cond_stage_model = clip(**(params))
-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.cond_stage_model = JointCLIP(huggingface_components)
+        self.tokenizer = JointTokenizer(huggingface_components)
+        self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)

         self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None

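Note (annotation, not part of the commit): CLIP is now built straight from the huggingface components dict; the clip_target/params indirection and the embedding_directory argument are gone, and dtype/device handling becomes an explicit .to() call before the ModelPatcher wraps the model. A sketch of the new call site, with names assumed from the surrounding diff:

    # Hypothetical: 'huggingface_components' as assembled by load_component above.
    clip = CLIP(huggingface_components)
    clip.cond_stage_model  # JointCLIP, already on the offload device in text-encoder dtype
    clip.tokenizer         # JointTokenizer exposing .clip_l / .clip_g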

@@ -112,13 +112,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(model=vae)

     if output_clip:
-        w = WeightsLoader()
-        clip_target = model_config.clip_target()
-        if clip_target is not None:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory)
-            w.cond_stage_model = clip.cond_stage_model
-            sd = model_config.process_clip_state_dict(sd)
-            load_model_weights(w, sd)
+        clip = CLIP(huggingface_components)

     left_over = sd.keys()
     if len(left_over) > 0:
@@ -177,7 +171,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
            embedder = conditioner.embedders[i]
            typename = type(embedder).__name__
            if typename == 'FrozenCLIPEmbedder':  # SDXL Clip L
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+                embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
                embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
                model_embeddings = embedder.transformer.text_model.embeddings
                model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
@@ -186,7 +180,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
                conditioner.embedders[i] = embedder
                text_cond_models.append(embedder)
            elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_g.tokenizer
+                embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
                embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
                embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
                model_embeddings = embedder.transformer.text_model.embeddings
@@ -201,14 +195,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
        else:
            sd_model.cond_stage_model = conditioner
    elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':  # SD15 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
        sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
        model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
            model_embeddings.token_embedding, sd_hijack.model_hijack)
        sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
    elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h
        sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
        model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
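Note (annotation, not part of the commit): every hunk above drops one attribute hop (.clip_x.tokenizer becomes .clip_x) because JointTokenizer now stores the raw transformers tokenizers directly rather than wrapper objects that hold a .tokenizer attribute. The invariant, sketched for the CLIP-L case:

    # Before: forge_objects.clip.tokenizer.clip_l was a wrapper exposing .tokenizer
    # After:  forge_objects.clip.tokenizer.clip_l is the transformers CLIPTokenizer itself
    from transformers import CLIPTokenizer
    assert isinstance(forge_objects.clip.tokenizer.clip_l, CLIPTokenizer)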