mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 17:09:49 +00:00
rework and speed up model loading
This commit is contained in:
@@ -19,6 +19,7 @@ import numpy as np
|
||||
from backend.loader import forge_loader
|
||||
from backend import memory_management
|
||||
from backend.args import dynamic_args
|
||||
from backend.utils import load_torch_file
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
@@ -242,45 +243,12 @@ def select_checkpoint():
|
||||
return checkpoint_info
|
||||
|
||||
|
||||
checkpoint_dict_replacements_sd1 = {
|
||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||
}
|
||||
|
||||
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
|
||||
'conditioner.embedders.0.': 'cond_stage_model.',
|
||||
}
|
||||
|
||||
|
||||
def transform_checkpoint_dict_key(k, replacements):
|
||||
for text, replacement in replacements.items():
|
||||
if k.startswith(text):
|
||||
k = replacement + k[len(text):]
|
||||
|
||||
return k
|
||||
pass
|
||||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
||||
pl_sd.pop("state_dict", None)
|
||||
|
||||
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
|
||||
|
||||
sd = {}
|
||||
for k, v in pl_sd.items():
|
||||
if is_sd2_turbo:
|
||||
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
|
||||
else:
|
||||
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
|
||||
|
||||
if new_key is not None:
|
||||
sd[new_key] = v
|
||||
|
||||
pl_sd.clear()
|
||||
pl_sd.update(sd)
|
||||
|
||||
return pl_sd
|
||||
pass
|
||||
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
@@ -312,23 +280,7 @@ def read_metadata_from_safetensors(filename):
|
||||
|
||||
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||
|
||||
if not shared.opts.disable_mmap_load_safetensors:
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
||||
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
||||
if print_global_state and "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
return sd
|
||||
pass
|
||||
|
||||
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
@@ -343,25 +295,14 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
res = load_torch_file(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class SkipWritingToConfig:
|
||||
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||
|
||||
skip = False
|
||||
previous = None
|
||||
|
||||
def __enter__(self):
|
||||
self.previous = SkipWritingToConfig.skip
|
||||
SkipWritingToConfig.skip = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
pass
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
@@ -434,12 +375,6 @@ def apply_alpha_schedule_override(sd_model, p=None):
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
|
||||
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
@@ -479,19 +414,7 @@ model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img()
|
||||
extra_networks.activate(p, {})
|
||||
|
||||
if hasattr(sd_model, 'get_learned_conditioning'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
else:
|
||||
d = sd_model.cond_stage_model([""])
|
||||
|
||||
if isinstance(d, dict):
|
||||
d = d['crossattn']
|
||||
|
||||
return d
|
||||
pass
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
@@ -511,22 +434,11 @@ def send_model_to_trash(m):
|
||||
|
||||
|
||||
def instantiate_from_config(config, state_dict=None):
|
||||
constructor = get_obj_from_str(config["target"])
|
||||
|
||||
params = {**config.get("params", {})}
|
||||
|
||||
if state_dict and "state_dict" in params and params["state_dict"] is None:
|
||||
params["state_dict"] = state_dict
|
||||
|
||||
return constructor(**params)
|
||||
pass
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
pass
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -568,9 +480,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
del state_dict
|
||||
|
||||
# clean up cache if limit is reached
|
||||
|
||||
@@ -329,7 +329,7 @@ class UiSettings:
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
fn=button_set_checkpoint_change,
|
||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||
js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||
inputs=[main_entry.ui_checkpoint, self.dummy_component],
|
||||
outputs=[main_entry.ui_checkpoint, self.text_settings],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user