diff --git a/backend/clip.py b/backend/clip.py
new file mode 100644
index 00000000..7adbded3
--- /dev/null
+++ b/backend/clip.py
@@ -0,0 +1,14 @@
+import torch
+
+
+class JointTokenizer:
+    def __init__(self, huggingface_components):
+        self.clip_l = huggingface_components.get('tokenizer', None)
+        self.clip_g = huggingface_components.get('tokenizer_2', None)
+
+
+class JointCLIP(torch.nn.Module):
+    def __init__(self, huggingface_components):
+        super().__init__()
+        self.clip_l = huggingface_components.get('text_encoder', None)
+        self.clip_g = huggingface_components.get('text_encoder_2', None)
diff --git a/backend/loader.py b/backend/loader.py
index f513f391..2894b097 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -3,19 +3,59 @@ import importlib
 
 from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
-from backend.vae import load_vae
-
+from transformers import modeling_utils
+from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
+from backend.operations import using_forge_operations
+from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
+from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
 
 dir_path = os.path.dirname(__file__)
 
 
-def load_component(component_name, lib_name, cls_name, repo_path, sd):
+def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)
-    if component_name in ['scheduler', 'tokenizer']:
-        cls = getattr(importlib.import_module(lib_name), cls_name)
-        return cls.from_pretrained(os.path.join(repo_path, component_name))
-    if cls_name in ['AutoencoderKL']:
-        return load_vae(sd, config_path)
+
+    if component_name in ['feature_extractor', 'safety_checker']:
+        return None
+
+    if lib_name in ['transformers', 'diffusers']:
+        if component_name in ['scheduler'] or component_name.startswith('tokenizer'):
+            cls = getattr(importlib.import_module(lib_name), cls_name)
+            return cls.from_pretrained(os.path.join(repo_path, component_name))
+        if cls_name in ['AutoencoderKL']:
+            sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
+            config = IntegratedAutoencoderKL.load_config(config_path)
+
+            with using_forge_operations():
+                model = IntegratedAutoencoderKL.from_config(config)
+
+            load_state_dict(model, sd)
+            return model
+        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+            if component_name == 'text_encoder':
+                sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
+            elif component_name == 'text_encoder_2':
+                sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
+            else:
+                raise ValueError(f"Wrong component_name: {component_name}")
+
+            if 'model.text_projection' in sd:
+                sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
+                sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
+                                                 "model.text_projection.weight": "text_projection",
+                                                 "model.logit_scale": "logit_scale"})
+
+            config = CLIPTextConfig.from_pretrained(config_path)
+
+            with modeling_utils.no_init_weights():
+                with using_forge_operations():
+                    model = IntegratedCLIP(config)
+
+            load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
+                                                      'transformer.text_model.embeddings.position_ids'])
+            return model
+
+    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
     return None
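
For reviewers: the new load_component dispatch above is meant to be driven once per entry of a diffusers-style model_index.json. A minimal sketch of such a caller, assuming a local SDXL checkpoint and config folder; the file name, repo_path, and component table below are illustrative, not part of this patch:

    import safetensors.torch
    from backend.loader import load_component

    # Flat single-file checkpoint; keys still carry ldm/sgm prefixes such as
    # 'first_stage_model.' or 'conditioner.embedders.0.' at this point.
    state_dict = safetensors.torch.load_file('sd_xl_base_1.0.safetensors')

    repo_path = 'configs/sdxl'  # hypothetical folder holding per-component configs
    components = {
        'vae': ('diffusers', 'AutoencoderKL'),
        'text_encoder': ('transformers', 'CLIPTextModel'),
        'text_encoder_2': ('transformers', 'CLIPTextModelWithProjection'),
        'tokenizer': ('transformers', 'CLIPTokenizer'),
        'tokenizer_2': ('transformers', 'CLIPTokenizer'),
    }

    huggingface_components = {
        name: load_component(name, lib, cls, repo_path, state_dict)
        for name, (lib, cls) in components.items()
    }

Note that load_component consumes state_dict in place: try_filter_state_dict pops every matched key, so whatever remains afterwards surfaces as left_over in forge_loader.py further down.
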
diff --git a/backend/nn/clip.py b/backend/nn/clip.py
new file mode 100644
index 00000000..c65f7b2a
--- /dev/null
+++ b/backend/nn/clip.py
@@ -0,0 +1,11 @@
+import torch
+
+from transformers import CLIPTextModel, CLIPTextConfig
+
+
+class IntegratedCLIP(torch.nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        self.transformer = CLIPTextModel(config)
+        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
diff --git a/backend/state_dict.py b/backend/state_dict.py
index 5e86a854..6a5ab7e5 100644
--- a/backend/state_dict.py
+++ b/backend/state_dict.py
@@ -1,12 +1,95 @@
 import torch
 
 
-def filter_state_dict_with_prefix(sd, prefix):
+def load_state_dict(model, sd, ignore_errors=[]):
+    missing, unexpected = model.load_state_dict(sd, strict=False)
+    missing = [x for x in missing if x not in ignore_errors]
+    unexpected = [x for x in unexpected if x not in ignore_errors]
+    if len(missing) > 0:
+        print(f'{type(model).__name__} Missing: {missing}')
+    if len(unexpected) > 0:
+        print(f'{type(model).__name__} Unexpected: {unexpected}')
+    return
+
+
+def state_dict_has(sd, prefix):
+    return any(x.startswith(prefix) for x in sd.keys())
+
+
+def filter_state_dict_with_prefix(sd, prefix, new_prefix=''):
     new_sd = {}
     for k, v in list(sd.items()):
         if k.startswith(prefix):
-            new_sd[k[len(prefix):]] = v
+            new_sd[new_prefix + k[len(prefix):]] = v
             del sd[k]
     return new_sd
+
+
+def try_filter_state_dict(sd, prefix_list, new_prefix=''):
+    for prefix in prefix_list:
+        if state_dict_has(sd, prefix):
+            return filter_state_dict_with_prefix(sd, prefix, new_prefix)
+    return {}
+
+
+def transformers_convert(sd, prefix_from, prefix_to, number):
+    keys_to_replace = {
+        "{}positional_embedding": "{}embeddings.position_embedding.weight",
+        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
+        "{}ln_final.weight": "{}final_layer_norm.weight",
+        "{}ln_final.bias": "{}final_layer_norm.bias",
+    }
+
+    for k in keys_to_replace:
+        x = k.format(prefix_from)
+        if x in sd:
+            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
+
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(number):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+    return sd
+
+
+def state_dict_key_replace(state_dict, keys_to_replace):
+    for x in keys_to_replace:
+        if x in state_dict:
+            state_dict[keys_to_replace[x]] = state_dict.pop(x)
+    return state_dict
+
+
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
+        for x in replace:
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out
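
A quick self-contained check of the fused-attention split that transformers_convert performs: the ldm-style in_proj weight is cut into equal q/k/v thirds. Toy 4-dim weights and a single layer, for illustration only:

    import torch
    from backend.state_dict import transformers_convert

    sd = {'model.transformer.resblocks.0.attn.in_proj_weight': torch.randn(12, 4)}
    sd = transformers_convert(sd, 'model.', 'transformer.text_model.', number=1)

    # The fused 12x4 in_proj weight becomes three 4x4 projections:
    assert sd['transformer.text_model.encoder.layers.0.self_attn.q_proj.weight'].shape == (4, 4)
    assert sd['transformer.text_model.encoder.layers.0.self_attn.k_proj.weight'].shape == (4, 4)
    assert sd['transformer.text_model.encoder.layers.0.self_attn.v_proj.weight'].shape == (4, 4)
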
diff --git a/backend/vae.py b/backend/vae.py
deleted file mode 100644
index e6649db0..00000000
--- a/backend/vae.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from backend.state_dict import filter_state_dict_with_prefix
-from backend.operations import using_forge_operations
-from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
-
-
-def load_vae(state_dict, config_path):
-    config = IntegratedAutoencoderKL.load_config(config_path)
-
-    with using_forge_operations():
-        model = IntegratedAutoencoderKL.from_config(config)
-
-    vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.")
-    model.load_state_dict(vae_state_dict, strict=True)
-    return model
diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 39e594c0..187f06c2 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -352,7 +352,7 @@ class LoadedModel:
             elif hasattr(m, "weight"):
                 m.to(self.device)
                 mem_counter += module_size(m)
-                print(f"[Memory Management] {flag} Loader Disabled for ", m)
+                print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__)
 
         print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
         print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index fdf6f777..6354f4b3 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -97,22 +97,22 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     return model, clip
 
 
+from backend.clip import JointCLIP, JointTokenizer
+
+
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False):
+    def __init__(self, huggingface_components, no_init=False):
         if no_init:
             return
-        params = target.params.copy()
-        clip = target.clip
-        tokenizer = target.tokenizer
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        text_encoder_dtype = model_management.text_encoder_dtype(load_device)
 
-        self.cond_stage_model = clip(**(params))
+        self.cond_stage_model = JointCLIP(huggingface_components)
+        self.tokenizer = JointTokenizer(huggingface_components)
 
-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
 
         self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
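
The rewritten CLIP wrapper consumes the huggingface_components dict keyed by diffusers component names; JointCLIP and JointTokenizer then expose the parts as clip_l / clip_g. A sketch with illustrative variable names (for SD1.5 checkpoints the *_2 entries are absent and clip_g ends up None):

    from ldm_patched.modules.sd import CLIP

    huggingface_components = {
        'tokenizer': tokenizer,            # CLIPTokenizer (CLIP-L)
        'tokenizer_2': tokenizer_2,        # CLIPTokenizer (CLIP-G, SDXL only)
        'text_encoder': text_encoder,      # IntegratedCLIP (CLIP-L)
        'text_encoder_2': text_encoder_2,  # IntegratedCLIP (CLIP-G, SDXL only)
    }
    clip = CLIP(huggingface_components)

    clip.cond_stage_model.clip_l.transformer  # CLIP-L text model
    clip.tokenizer.clip_l                     # CLIP-L tokenizer

One behavioral change worth flagging: the constructor no longer takes embedding_directory, so textual-inversion embeddings are not resolved here anymore; the EmbeddingsWithFixes hijack applied in forge_loader.py below appears to take over that responsibility.
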
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index 3c019b94..38eddf08 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -112,13 +112,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(model=vae)
 
     if output_clip:
-        w = WeightsLoader()
-        clip_target = model_config.clip_target()
-        if clip_target is not None:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory)
-            w.cond_stage_model = clip.cond_stage_model
-            sd = model_config.process_clip_state_dict(sd)
-            load_model_weights(w, sd)
+        clip = CLIP(huggingface_components)
 
     left_over = sd.keys()
     if len(left_over) > 0:
@@ -177,7 +171,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
         embedder = conditioner.embedders[i]
         typename = type(embedder).__name__
         if typename == 'FrozenCLIPEmbedder':  # SDXL Clip L
-            embedder.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+            embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
             embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
             model_embeddings = embedder.transformer.text_model.embeddings
             model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
@@ -186,7 +180,7 @@
             conditioner.embedders[i] = embedder
             text_cond_models.append(embedder)
         elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
-            embedder.tokenizer = forge_objects.clip.tokenizer.clip_g.tokenizer
+            embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
             embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
             embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
             model_embeddings = embedder.transformer.text_model.embeddings
@@ -201,14 +195,14 @@
         else:
             sd_model.cond_stage_model = conditioner
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':  # SD15 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
        sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
         model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
             model_embeddings.token_embedding, sd_hijack.model_hijack)
         sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h.tokenizer
+        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h
         sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
         model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
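
Finally, a standalone sketch of why load_state_dict can list text_projection and logit_scale under ignore_errors: IntegratedCLIP initializes text_projection to an identity matrix, so a CLIP-L checkpoint without one gets a no-op projection, while a CLIP-G checkpoint simply overwrites it. The config values below are the usual CLIP-L ones, used here only for illustration:

    import torch
    from transformers import CLIPTextConfig
    from backend.nn.clip import IntegratedCLIP

    config = CLIPTextConfig(hidden_size=768, intermediate_size=3072,
                            num_hidden_layers=12, num_attention_heads=12)
    model = IntegratedCLIP(config)

    # Identity projection: multiplying by it is a no-op until real weights load.
    assert torch.equal(model.text_projection, torch.eye(768))
    # logit_scale starts near ln(100) ~= 4.605, where trained CLIP models saturate.
    assert abs(model.logit_scale.item() - 4.6055) < 1e-4
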