rework and speed up model loading

layerdiffusion
2024-08-06 14:23:21 -07:00
parent 4122e7f239
commit c1b23bd494
2 changed files with 10 additions and 101 deletions

View File

@@ -19,6 +19,7 @@ import numpy as np
from backend.loader import forge_loader
from backend import memory_management
from backend.args import dynamic_args
from backend.utils import load_torch_file
model_dir = "Stable-diffusion"
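
The only addition in this hunk is the load_torch_file import; the deletions below remove this module's own loading path in its favor. backend.utils.load_torch_file itself is not shown in this diff, but a helper of that shape usually just dispatches on file extension. A minimal sketch of the idea (names and behavior assumed, not taken from the Forge source):

import torch
import safetensors.torch


def load_torch_file_sketch(path, device="cpu"):
    # Hypothetical stand-in for backend.utils.load_torch_file, which this
    # diff does not show. safetensors files can be loaded straight onto the
    # target device; legacy .ckpt files fall back to torch.load.
    if path.lower().endswith(".safetensors"):
        return safetensors.torch.load_file(path, device=device)
    sd = torch.load(path, map_location=device, weights_only=True)
    return sd.get("state_dict", sd)  # unwrap Lightning-style checkpoints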
@@ -242,45 +243,12 @@ def select_checkpoint():
    return checkpoint_info
checkpoint_dict_replacements_sd1 = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
    'conditioner.embedders.0.': 'cond_stage_model.',
}
def transform_checkpoint_dict_key(k, replacements):
    for text, replacement in replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k

pass
def get_state_dict_from_checkpoint(pl_sd):
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024

    sd = {}
    for k, v in pl_sd.items():
        if is_sd2_turbo:
            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
        else:
            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)

        if new_key is not None:
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd

pass
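
These deleted helpers existed to rewrite legacy key prefixes so that SD1 and SD 2.1 Turbo checkpoints matched the LDM module tree. What the mapping did, using the tables above:

# SGM-style SD 2.1 Turbo key...
k = 'conditioner.embedders.0.model.ln_final.weight'
new_k = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
# ...becomes the LDM-style key the rest of the loader expected:
# 'cond_stage_model.model.ln_final.weight'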
def read_metadata_from_safetensors(filename):
@@ -312,23 +280,7 @@ def read_metadata_from_safetensors(filename):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd

pass
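
The deleted read_state_dict also shows the mmap trade-off the old path exposed as an option: safetensors.torch.load_file memory-maps the file, while the disable_mmap path reads the whole file into RAM and then moves every tensor to the device. A rough way to compare the two on a given checkpoint (illustrative, not from the codebase):

import time
import safetensors.torch


def compare_safetensors_paths(path, device="cpu"):
    # Memory-mapped load, used when disable_mmap_load_safetensors is off.
    t0 = time.perf_counter()
    safetensors.torch.load_file(path, device=device)
    t1 = time.perf_counter()
    # Full read into RAM, used when it is on.
    with open(path, 'rb') as f:
        safetensors.torch.load(f.read())
    t2 = time.perf_counter()
    print(f"mmap: {t1 - t0:.2f}s, full read: {t2 - t1:.2f}s")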
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
@@ -343,25 +295,14 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    res = read_state_dict(checkpoint_info.filename)
    res = load_torch_file(checkpoint_info.filename)
    timer.record("load weights from disk")

    return res
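
checkpoints_loaded, checked at the top of this hunk, is the in-RAM checkpoint cache, so load_torch_file only runs on a cache miss. In upstream webui the cache is an OrderedDict evicted oldest-first; a minimal sketch of that pattern (the real size limit lives in shared.opts; names here are illustrative):

import collections

checkpoints_loaded = collections.OrderedDict()


def cache_state_dict(checkpoint_info, state_dict, limit):
    # Keep at most `limit` checkpoints in RAM, dropping the oldest first.
    checkpoints_loaded[checkpoint_info] = state_dict
    while len(checkpoints_loaded) > limit:
        checkpoints_loaded.popitem(last=False)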
class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

    skip = False
    previous = None

    def __enter__(self):
        self.previous = SkipWritingToConfig.skip
        SkipWritingToConfig.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous

pass
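
SkipWritingToConfig survives the commit unchanged, and its use follows directly from the definition: any code path that loads weights without wanting them recorded as the user's chosen checkpoint wraps the call, e.g.:

# Load weights without recording the checkpoint name in the user config
# (load_model_weights is the function named in the class docstring):
with SkipWritingToConfig():
    load_model_weights(model, checkpoint_info, state_dict, timer)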
def check_fp8(model):
@@ -434,12 +375,6 @@ def apply_alpha_schedule_override(sd_model, p=None):
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
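
These four module-level keys are architecture probes: each exists only in the state dict of its model family, so its presence identifies the checkpoint type before any model is built. A sketch of how such probes are typically consumed (illustrative, not a function from this file):

def guess_model_family(state_dict):
    # Check the most specific keys first; the refiner key alone would also
    # match other SGM-format checkpoints.
    if sdxl_clip_weight in state_dict:
        return 'SDXL'
    if sdxl_refiner_clip_weight in state_dict:
        return 'SDXL Refiner'
    if sd2_clip_weight in state_dict:
        return 'SD2'
    if sd1_clip_weight in state_dict:
        return 'SD1'
    return 'unknown'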
class SdModelData:
    def __init__(self):
        self.sd_model = None
@@ -479,19 +414,7 @@ model_data = SdModelData()
def get_empty_cond(sd_model):
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'get_learned_conditioning'):
        d = sd_model.get_learned_conditioning([""])
    else:
        d = sd_model.cond_stage_model([""])

    if isinstance(d, dict):
        d = d['crossattn']

    return d

pass
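
get_empty_cond computes the conditioning for an empty prompt, with the dict branch covering SDXL-style conditioners that return several tensors at once. Typical use is warming the text encoder and caching the unconditional embedding up front (illustrative):

empty_cond = get_empty_cond(sd_model)
# For SD1-class models this is the familiar CLIP sequence embedding,
# e.g. a tensor of shape [1, 77, 768].
print(empty_cond.shape)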
def send_model_to_cpu(m):
@@ -511,22 +434,11 @@ def send_model_to_trash(m):
def instantiate_from_config(config, state_dict=None):
    constructor = get_obj_from_str(config["target"])

    params = {**config.get("params", {})}

    if state_dict and "state_dict" in params and params["state_dict"] is None:
        params["state_dict"] = state_dict

    return constructor(**params)

pass
def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)

    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)

    return getattr(importlib.import_module(module, package=None), cls)

pass
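
instantiate_from_config and get_obj_from_str implement the LDM-style config convention: a dict names a constructor by dotted path under "target" and its keyword arguments under "params". Both functions are shown in full above, so this usage follows directly (torch.nn.Linear is just an arbitrary example target):

config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 768, "out_features": 320},
}
layer = instantiate_from_config(config)  # same as torch.nn.Linear(768, 320)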
@torch.no_grad()
@@ -568,9 +480,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    del state_dict

    # clean up cache if limit is reached

View File

@@ -329,7 +329,7 @@ class UiSettings:
        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
            fn=button_set_checkpoint_change,
            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            inputs=[main_entry.ui_checkpoint, self.dummy_component],
            outputs=[main_entry.ui_checkpoint, self.text_settings],
        )
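
The second file's one-line change tracks Gradio 4, which renamed the event listeners' private _js keyword to the public js; the JavaScript body itself is unchanged. A minimal reproduction of the new spelling (assuming Gradio 4.x):

import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox()
    btn = gr.Button("Upper-case")
    # Gradio 4.x spelling: `js` (was `_js` in 3.x). The JS function runs in
    # the browser on the input values before the Python fn receives them.
    btn.click(fn=lambda s: s, js="(s) => s.toUpperCase()", inputs=[box], outputs=[box])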