From c1b23bd49479c890255193c269857892c91376ec Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Tue, 6 Aug 2024 14:23:21 -0700
Subject: [PATCH] rework and speed up model loading

---
 modules/sd_models.py   | 109 ++++------------------------------------
 modules/ui_settings.py |   2 +-
 2 files changed, 10 insertions(+), 101 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index a000d57f..bb5bc34c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -19,6 +19,7 @@ import numpy as np
 from backend.loader import forge_loader
 from backend import memory_management
 from backend.args import dynamic_args
+from backend.utils import load_torch_file
 
 model_dir = "Stable-diffusion"
 
@@ -242,45 +243,12 @@ def select_checkpoint():
     return checkpoint_info
 
 
-checkpoint_dict_replacements_sd1 = {
-    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
-    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
-    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
-}
-
-checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
-    'conditioner.embedders.0.': 'cond_stage_model.',
-}
-
-
 def transform_checkpoint_dict_key(k, replacements):
-    for text, replacement in replacements.items():
-        if k.startswith(text):
-            k = replacement + k[len(text):]
-
-    return k
+    pass
 
 
 def get_state_dict_from_checkpoint(pl_sd):
-    pl_sd = pl_sd.pop("state_dict", pl_sd)
-    pl_sd.pop("state_dict", None)
-
-    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
-
-    sd = {}
-    for k, v in pl_sd.items():
-        if is_sd2_turbo:
-            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
-        else:
-            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
-
-        if new_key is not None:
-            sd[new_key] = v
-
-    pl_sd.clear()
-    pl_sd.update(sd)
-
-    return pl_sd
+    pass
 
 
 def read_metadata_from_safetensors(filename):
@@ -312,23 +280,7 @@ def read_metadata_from_safetensors(filename):
 
 
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
-    _, extension = os.path.splitext(checkpoint_file)
-    if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
-
-        if not shared.opts.disable_mmap_load_safetensors:
-            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
-        else:
-            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
-            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
-    else:
-        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
-
-    if print_global_state and "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
-
-    sd = get_state_dict_from_checkpoint(pl_sd)
-    return sd
+    pass
 
 
 def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
@@ -343,25 +295,14 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
         return checkpoints_loaded[checkpoint_info]
 
     print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
-    res = read_state_dict(checkpoint_info.filename)
+    res = load_torch_file(checkpoint_info.filename)
     timer.record("load weights from disk")
 
     return res
 
 
 class SkipWritingToConfig:
-    """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
-
-    skip = False
-    previous = None
-
-    def __enter__(self):
-        self.previous = SkipWritingToConfig.skip
-        SkipWritingToConfig.skip = True
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        SkipWritingToConfig.skip = self.previous
+    pass
 
 
 def check_fp8(model):
@@ -434,12 +375,6 @@ def apply_alpha_schedule_override(sd_model, p=None):
         sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
 
 
-sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
-sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
-sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
-
-
 class SdModelData:
     def __init__(self):
         self.sd_model = None
@@ -479,19 +414,7 @@ model_data = SdModelData()
 
 
 def get_empty_cond(sd_model):
-
-    p = processing.StableDiffusionProcessingTxt2Img()
-    extra_networks.activate(p, {})
-
-    if hasattr(sd_model, 'get_learned_conditioning'):
-        d = sd_model.get_learned_conditioning([""])
-    else:
-        d = sd_model.cond_stage_model([""])
-
-    if isinstance(d, dict):
-        d = d['crossattn']
-
-    return d
+    pass
 
 
 def send_model_to_cpu(m):
@@ -511,22 +434,11 @@ def send_model_to_trash(m):
 
 
 def instantiate_from_config(config, state_dict=None):
-    constructor = get_obj_from_str(config["target"])
-
-    params = {**config.get("params", {})}
-
-    if state_dict and "state_dict" in params and params["state_dict"] is None:
-        params["state_dict"] = state_dict
-
-    return constructor(**params)
+    pass
 
 
 def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
+    pass
 
 
 @torch.no_grad()
@@ -568,9 +480,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
         timer.record("calculate hash")
 
-    if not SkipWritingToConfig.skip:
-        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
-
     del state_dict
 
     # clean up cache if limit is reached
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index bc9c8296..f30fc31e 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -329,7 +329,7 @@ class UiSettings:
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
             fn=button_set_checkpoint_change,
-            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+            js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
             inputs=[main_entry.ui_checkpoint, self.dummy_component],
             outputs=[main_entry.ui_checkpoint, self.text_settings],
         )
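
Note on the change: the patch routes all checkpoint reads through backend.utils.load_torch_file and stubs out the old read_state_dict / get_state_dict_from_checkpoint path, so the SD1/SD2.1-Turbo key-remapping tables and the CLIP-weight detection constants become dead code. The implementation of load_torch_file itself is not part of this diff; as a rough sketch only, a ComfyUI-style helper of this name typically looks like the following (the signature and behavior here are assumptions, not the actual Forge code):

    import torch
    import safetensors.torch

    def load_torch_file(ckpt, device=None):
        # Sketch only: backend.utils.load_torch_file is not shown in this diff.
        if device is None:
            device = torch.device("cpu")
        if ckpt.lower().endswith(".safetensors"):
            # Memory-mapped load; skips the full-file read, per-tensor copies,
            # and key rewriting of the old read_state_dict path.
            return safetensors.torch.load_file(ckpt, device=device.type)
        pl_sd = torch.load(ckpt, map_location=device)
        # Older .ckpt files nest the weights under a "state_dict" key.
        return pl_sd.get("state_dict", pl_sd)

The stubbed functions are presumably left in place as pass bodies, rather than deleted outright, so that extensions importing these names from modules.sd_models continue to import cleanly.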