mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
rework and speed up model loading
This commit is contained in:
@@ -19,6 +19,7 @@ import numpy as np
|
|||||||
from backend.loader import forge_loader
|
from backend.loader import forge_loader
|
||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
from backend.args import dynamic_args
|
from backend.args import dynamic_args
|
||||||
|
from backend.utils import load_torch_file
|
||||||
|
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
@@ -242,45 +243,12 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
checkpoint_dict_replacements_sd1 = {
|
|
||||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
|
||||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
|
||||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
|
||||||
}
|
|
||||||
|
|
||||||
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
|
|
||||||
'conditioner.embedders.0.': 'cond_stage_model.',
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def transform_checkpoint_dict_key(k, replacements):
|
def transform_checkpoint_dict_key(k, replacements):
|
||||||
for text, replacement in replacements.items():
|
pass
|
||||||
if k.startswith(text):
|
|
||||||
k = replacement + k[len(text):]
|
|
||||||
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict_from_checkpoint(pl_sd):
|
def get_state_dict_from_checkpoint(pl_sd):
|
||||||
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
pass
|
||||||
pl_sd.pop("state_dict", None)
|
|
||||||
|
|
||||||
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
|
|
||||||
|
|
||||||
sd = {}
|
|
||||||
for k, v in pl_sd.items():
|
|
||||||
if is_sd2_turbo:
|
|
||||||
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
|
|
||||||
else:
|
|
||||||
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
|
|
||||||
|
|
||||||
if new_key is not None:
|
|
||||||
sd[new_key] = v
|
|
||||||
|
|
||||||
pl_sd.clear()
|
|
||||||
pl_sd.update(sd)
|
|
||||||
|
|
||||||
return pl_sd
|
|
||||||
|
|
||||||
|
|
||||||
def read_metadata_from_safetensors(filename):
|
def read_metadata_from_safetensors(filename):
|
||||||
@@ -312,23 +280,7 @@ def read_metadata_from_safetensors(filename):
|
|||||||
|
|
||||||
|
|
||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
pass
|
||||||
if extension.lower() == ".safetensors":
|
|
||||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
|
||||||
|
|
||||||
if not shared.opts.disable_mmap_load_safetensors:
|
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
|
||||||
else:
|
|
||||||
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
|
||||||
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
|
||||||
else:
|
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
|
||||||
|
|
||||||
if print_global_state and "global_step" in pl_sd:
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
|
|
||||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
|
||||||
return sd
|
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||||
@@ -343,25 +295,14 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
return checkpoints_loaded[checkpoint_info]
|
return checkpoints_loaded[checkpoint_info]
|
||||||
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||||
res = read_state_dict(checkpoint_info.filename)
|
res = load_torch_file(checkpoint_info.filename)
|
||||||
timer.record("load weights from disk")
|
timer.record("load weights from disk")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class SkipWritingToConfig:
|
class SkipWritingToConfig:
|
||||||
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
pass
|
||||||
|
|
||||||
skip = False
|
|
||||||
previous = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.previous = SkipWritingToConfig.skip
|
|
||||||
SkipWritingToConfig.skip = True
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
|
||||||
SkipWritingToConfig.skip = self.previous
|
|
||||||
|
|
||||||
|
|
||||||
def check_fp8(model):
|
def check_fp8(model):
|
||||||
@@ -434,12 +375,6 @@ def apply_alpha_schedule_override(sd_model, p=None):
|
|||||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
|
||||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
|
||||||
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
|
||||||
|
|
||||||
|
|
||||||
class SdModelData:
|
class SdModelData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sd_model = None
|
self.sd_model = None
|
||||||
@@ -479,19 +414,7 @@ model_data = SdModelData()
|
|||||||
|
|
||||||
|
|
||||||
def get_empty_cond(sd_model):
|
def get_empty_cond(sd_model):
|
||||||
|
pass
|
||||||
p = processing.StableDiffusionProcessingTxt2Img()
|
|
||||||
extra_networks.activate(p, {})
|
|
||||||
|
|
||||||
if hasattr(sd_model, 'get_learned_conditioning'):
|
|
||||||
d = sd_model.get_learned_conditioning([""])
|
|
||||||
else:
|
|
||||||
d = sd_model.cond_stage_model([""])
|
|
||||||
|
|
||||||
if isinstance(d, dict):
|
|
||||||
d = d['crossattn']
|
|
||||||
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def send_model_to_cpu(m):
|
def send_model_to_cpu(m):
|
||||||
@@ -511,22 +434,11 @@ def send_model_to_trash(m):
|
|||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config, state_dict=None):
|
def instantiate_from_config(config, state_dict=None):
|
||||||
constructor = get_obj_from_str(config["target"])
|
pass
|
||||||
|
|
||||||
params = {**config.get("params", {})}
|
|
||||||
|
|
||||||
if state_dict and "state_dict" in params and params["state_dict"] is None:
|
|
||||||
params["state_dict"] = state_dict
|
|
||||||
|
|
||||||
return constructor(**params)
|
|
||||||
|
|
||||||
|
|
||||||
def get_obj_from_str(string, reload=False):
|
def get_obj_from_str(string, reload=False):
|
||||||
module, cls = string.rsplit(".", 1)
|
pass
|
||||||
if reload:
|
|
||||||
module_imp = importlib.import_module(module)
|
|
||||||
importlib.reload(module_imp)
|
|
||||||
return getattr(importlib.import_module(module, package=None), cls)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -568,9 +480,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
timer.record("calculate hash")
|
timer.record("calculate hash")
|
||||||
|
|
||||||
if not SkipWritingToConfig.skip:
|
|
||||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
|
||||||
|
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
# clean up cache if limit is reached
|
# clean up cache if limit is reached
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class UiSettings:
|
|||||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||||
button_set_checkpoint.click(
|
button_set_checkpoint.click(
|
||||||
fn=button_set_checkpoint_change,
|
fn=button_set_checkpoint_change,
|
||||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||||
inputs=[main_entry.ui_checkpoint, self.dummy_component],
|
inputs=[main_entry.ui_checkpoint, self.dummy_component],
|
||||||
outputs=[main_entry.ui_checkpoint, self.text_settings],
|
outputs=[main_entry.ui_checkpoint, self.text_settings],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user