mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-21 14:59:05 +00:00
Gradio 4 + WebUI 1.10
This commit is contained in:
@@ -1,19 +1,18 @@
|
||||
import collections
|
||||
import os.path
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
import gc
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
import numpy as np
|
||||
@@ -33,6 +32,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
    """Model family tag stored on a loaded model as ``model.model_type``.

    ``set_model_type`` in this module picks one of these members after
    inspecting the model object and its state dict.
    """

    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4
    SD3 = 5
|
||||
|
||||
|
||||
def replace_key(d, key, new_key, value):
|
||||
keys = list(d.keys())
|
||||
|
||||
@@ -155,6 +162,7 @@ def list_models():
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
||||
model_url = None
|
||||
expected_sha256 = None
|
||||
else:
|
||||
model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"
|
||||
|
||||
@@ -286,17 +294,21 @@ def read_metadata_from_safetensors(filename):
|
||||
json_start = file.read(2)
|
||||
|
||||
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
|
||||
json_data = json_start + file.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
|
||||
res = {}
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
json_data = json_start + file.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f"Error reading metadata from file: {filename}", exc_info=True)
|
||||
|
||||
return res
|
||||
|
||||
@@ -368,42 +380,39 @@ def check_fp8(model):
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def set_model_type(model, state_dict):
    """Set the ``is_*`` family flags and ``model_type`` on *model*.

    All five boolean flags are first reset to ``False``; then the first
    matching rule below decides which flag(s) become ``True`` and which
    ``ModelType`` member is stored in ``model.model_type``:

    1. ``state_dict`` contains the ``x_embedder`` projection weight -> SD3.
    2. *model* has a ``conditioner`` attribute -> ``is_sdxl``; additionally
       SSD when the middle-block attention weight is missing from
       ``state_dict``, otherwise SDXL.
    3. ``model.cond_stage_model`` has a ``model`` attribute -> SD2.
    4. Anything else -> SD1.

    Args:
        model: Object to annotate; mutated in place.
        state_dict: Mapping of weight names, used only for key membership.
    """
    # Start from a clean slate so only the detected flags end up set.
    for flag_name in ("is_sd1", "is_sd2", "is_sdxl", "is_ssd", "is_sd3"):
        setattr(model, flag_name, False)

    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
        model.is_sd3 = True
        model.model_type = ModelType.SD3
        return

    if hasattr(model, 'conditioner'):
        model.is_sdxl = True

        has_middle_attention = 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' in state_dict.keys()
        if has_middle_attention:
            model.model_type = ModelType.SDXL
        else:
            # SSD keeps is_sdxl set as well; only model_type tells them apart.
            model.is_ssd = True
            model.model_type = ModelType.SSD
        return

    if hasattr(model.cond_stage_model, 'model'):
        model.is_sd2 = True
        model.model_type = ModelType.SD2
        return

    model.is_sd1 = True
    model.model_type = ModelType.SD1
|
||||
|
||||
|
||||
def set_model_fields(model):
    """Give *model* a ``latent_channels`` attribute of 4 if it has none.

    An existing ``latent_channels`` value is left untouched.
    """
    if hasattr(model, 'latent_channels'):
        return

    model.latent_channels = 4
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Load checkpoint weights and the matching VAE into *model*.

    Args:
        model: Target model; receives the weights plus the bookkeeping
            attributes ``sd_model_hash``, ``sd_model_checkpoint`` and
            ``sd_checkpoint_info``.
        checkpoint_info: Describes the checkpoint (title, filename, hashes).
        state_dict: Already-loaded weights, or ``None`` to read them via
            ``get_checkpoint_state_dict``.
        timer: Receives a ``record(...)`` call after each phase.

    Side effects: may write ``sd_model_checkpoint`` and always writes
    ``sd_checkpoint_hash`` into ``shared.opts.data``; updates the
    module-level ``checkpoints_loaded`` cache.
    """
    short_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    if shared.opts.sd_checkpoint_cache > 0:
        # Cache a shallow copy of the newly loaded weights.
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    del state_dict

    # Evict oldest entries until the cache is within the configured limit.
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = short_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        # fix for training
        model.logvar = model.logvar.to(devices.device)

    # Drop any previously loaded VAE state, then load the one resolved
    # for this checkpoint file.
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    resolved_vae = sd_vae.resolve_vae(checkpoint_info.filename)
    vae_file, vae_source = resolved_vae.tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
@@ -438,7 +447,7 @@ def enable_midas_autodownload():
|
||||
path = midas.api.ISL_PATHS[model_type]
|
||||
if not os.path.exists(path):
|
||||
if not os.path.exists(midas_path):
|
||||
mkdir(midas_path)
|
||||
os.mkdir(midas_path)
|
||||
|
||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||
request.urlretrieve(midas_urls[model_type], path)
|
||||
@@ -463,25 +472,76 @@ def patch_given_betas():
|
||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
|
||||
|
||||
def repair_config(sd_config, state_dict=None):
    """Patch a loaded model config in place so it suits this runtime.

    Adjustments, each applied only when the relevant section exists:

    - ``use_ema`` defaults to ``False`` when absent.
    - ``unet_config.params.use_fp16`` follows the command-line precision
      flags (``no_half`` wins; ``upcast_sampling`` or ``precision == "half"``
      enables it).
    - A ``"vanilla-xformers"`` first-stage attention type falls back to
      ``"vanilla"`` when xformers is unavailable.
    - The UnCLIP ``clip_stats_path`` is redirected from the hard-coded
      karlo directory to ``<models_path>/karlo``.
    - ``use_checkpoint`` is switched off for ``network_config`` and
      ``unet_config``.

    Args:
        sd_config: Config object exposing ``model.params`` with attribute
            access; mutated in place.
        state_dict: Accepted for call-site compatibility; not read here.
    """
    model_params = sd_config.model.params

    if not hasattr(model_params, "use_ema"):
        model_params.use_ema = False

    if hasattr(model_params, 'unet_config'):
        if shared.cmd_opts.no_half:
            model_params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            model_params.unet_config.params.use_fp16 = True

    # Some configs have no first_stage_config at all; touching it
    # unconditionally would raise, so guard the access.
    if hasattr(model_params, 'first_stage_config'):
        ddconfig = model_params.first_stage_config.params.ddconfig
        if getattr(ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
            ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(model_params, "noise_aug_config") and hasattr(model_params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_aug_params = model_params.noise_aug_config.params
        noise_aug_params.clip_stats_path = noise_aug_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    if hasattr(model_params, "network_config"):
        model_params.network_config.params.use_checkpoint = False
    if hasattr(model_params, "unet_config"):
        model_params.unet_config.params.use_checkpoint = False
|
||||
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """Return a rescaled copy of the cumulative-alpha schedule.

    Works on the square roots of the input: they are shifted so the last
    entry becomes zero, scaled so the first entry regains its original
    value, then squared again. The final entry of the result is finally
    overwritten with a fixed small constant.

    Args:
        alphas_cumprod: 1-D tensor of cumulative alpha products; not
            modified.

    Returns:
        A new tensor holding the rescaled schedule.
    """
    root = alphas_cumprod.sqrt()

    # Remember the end points before anything is shifted.
    root_first = root[0]
    root_last = root[-1]

    # Shift so the last timestep is zero.
    shifted = root - root_last

    # Scale so the first timestep is back to the old value.
    rescaled = shifted * (root_first / (root_first - root_last))

    # Revert the sqrt.
    alphas_bar = rescaled ** 2
    # NOTE(review): presumably a tiny positive floor so the final entry is
    # not exactly zero - confirm the origin of this constant.
    alphas_bar[-1] = 4.8973451890853435e-08
    return alphas_bar
|
||||
|
||||
|
||||
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR

    Does nothing unless *sd_model* has both ``alphas_cumprod`` and
    ``alphas_cumprod_original``. When *p* is given, each override that is
    applied is also noted in ``p.extra_generation_params``.
    """
    has_schedule = hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original')
    if not has_schedule:
        return

    # Always restart from the pristine schedule so overrides do not stack
    # across calls.
    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if p is not None:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        halved = sd_model.alphas_cumprod.half()
        sd_model.alphas_cumprod = halved.to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if p is not None:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        rescaled = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod)
        sd_model.alphas_cumprod = rescaled.to(shared.device)
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
@@ -532,11 +592,15 @@ def get_empty_cond(sd_model):
|
||||
p = processing.StableDiffusionProcessingTxt2Img()
|
||||
extra_networks.activate(p, {})
|
||||
|
||||
if hasattr(sd_model, 'conditioner'):
|
||||
if hasattr(sd_model, 'get_learned_conditioning'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
return d['crossattn']
|
||||
else:
|
||||
return sd_model.cond_stage_model([""])
|
||||
d = sd_model.cond_stage_model([""])
|
||||
|
||||
if isinstance(d, dict):
|
||||
d = d['crossattn']
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
@@ -555,6 +619,25 @@ def send_model_to_trash(m):
|
||||
pass
|
||||
|
||||
|
||||
def instantiate_from_config(config, state_dict=None):
    """Construct the object described by *config*.

    ``config["target"]`` is a dotted path resolved with ``get_obj_from_str``;
    ``config["params"]`` (optional) supplies the keyword arguments.

    Args:
        config: Mapping with a ``"target"`` entry and optionally ``"params"``.
        state_dict: When truthy, it is passed as the ``state_dict`` keyword
            argument, but only if ``params`` already has a ``"state_dict"``
            entry whose value is ``None``.

    Returns:
        The result of calling the resolved target with the keyword arguments.
    """
    factory = get_obj_from_str(config["target"])

    # Shallow copy so the caller's config is never mutated.
    kwargs = {**config.get("params", {})}

    if state_dict:
        # Only fill the slot the config explicitly left empty.
        if "state_dict" in kwargs and kwargs["state_dict"] is None:
            kwargs["state_dict"] = state_dict

    return factory(**kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Dotted path; everything before the last dot is imported as
            a module, the remainder is looked up as an attribute on it.
        reload: When true, the module is reloaded before the lookup.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If *string* contains no dot.
    """
    module_path, attr_name = string.rsplit(".", 1)

    if reload:
        importlib.reload(importlib.import_module(module_path))

    loaded_module = importlib.import_module(module_path, package=None)
    return getattr(loaded_module, attr_name)
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
@@ -585,6 +668,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
|
||||
sd_model.filename = checkpoint_info.filename
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
del state_dict
|
||||
|
||||
# clean up cache if limit is reached
|
||||
|
||||
Reference in New Issue
Block a user