Gradio 4 + WebUI 1.10

This commit is contained in:
layerdiffusion
2024-07-26 08:51:34 -07:00
parent e95333c556
commit e26abf87ec
201 changed files with 7562 additions and 4834 deletions

View File

@@ -1,19 +1,18 @@
import collections
import os.path
import importlib
import os
import sys
import threading
import enum
import torch
import re
import safetensors.torch
from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
import gc
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import numpy as np
@@ -33,6 +32,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
class ModelType(enum.Enum):
SD1 = 1
SD2 = 2
SDXL = 3
SSD = 4
SD3 = 5
def replace_key(d, key, new_key, value):
keys = list(d.keys())
@@ -155,6 +162,7 @@ def list_models():
cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
model_url = None
expected_sha256 = None
else:
model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"
@@ -286,17 +294,21 @@ def read_metadata_from_safetensors(filename):
json_start = file.read(2)
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
res = {}
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
try:
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
except Exception:
errors.report(f"Error reading metadata from file: {filename}", exc_info=True)
return res
@@ -368,42 +380,39 @@ def check_fp8(model):
return enable_fp8
def set_model_type(model, state_dict):
model.is_sd1 = False
model.is_sd2 = False
model.is_sdxl = False
model.is_ssd = False
model.is_sd3 = False
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
model.is_sd3 = True
model.model_type = ModelType.SD3
elif hasattr(model, 'conditioner'):
model.is_sdxl = True
if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
model.is_ssd = True
model.model_type = ModelType.SSD
else:
model.model_type = ModelType.SDXL
elif hasattr(model.cond_stage_model, 'model'):
model.is_sd2 = True
model.model_type = ModelType.SD2
else:
model.is_sd1 = True
model.model_type = ModelType.SD1
def set_model_fields(model):
if not hasattr(model, 'latent_channels'):
model.latent_channels = 4
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
del state_dict
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
if hasattr(model, 'logvar'):
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
return
def enable_midas_autodownload():
@@ -438,7 +447,7 @@ def enable_midas_autodownload():
path = midas.api.ISL_PATHS[model_type]
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
os.mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
@@ -463,25 +472,76 @@ def patch_given_betas():
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config):
def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
if hasattr(sd_config.model.params, 'first_stage_config'):
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
# Do not use checkpoint for inference.
# This helps prevent extra performance overhead on checking parameters.
# The perf overhead is about 100ms/it on 4090 for SDXL.
if hasattr(sd_config.model.params, "network_config"):
sd_config.model.params.network_config.params.use_checkpoint = False
if hasattr(sd_config.model.params, "unet_config"):
sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return alphas_bar
def apply_alpha_schedule_override(sd_model, p=None):
"""
Applies an override to the alpha schedule of the model according to settings.
- downcasts the alpha schedule to half precision
- rescales the alpha schedule to have zero terminal SNR
"""
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
return
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
if opts.use_downcasted_alpha_bar:
if p is not None:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
if opts.sd_noise_schedule == "Zero Terminal SNR":
if p is not None:
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
@@ -532,11 +592,15 @@ def get_empty_cond(sd_model):
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
if hasattr(sd_model, 'conditioner'):
if hasattr(sd_model, 'get_learned_conditioning'):
d = sd_model.get_learned_conditioning([""])
return d['crossattn']
else:
return sd_model.cond_stage_model([""])
d = sd_model.cond_stage_model([""])
if isinstance(d, dict):
d = d['crossattn']
return d
def send_model_to_cpu(m):
@@ -555,6 +619,25 @@ def send_model_to_trash(m):
pass
def instantiate_from_config(config, state_dict=None):
constructor = get_obj_from_str(config["target"])
params = {**config.get("params", {})}
if state_dict and "state_dict" in params and params["state_dict"] is None:
params["state_dict"] = state_dict
return constructor(**params)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -585,6 +668,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
sd_model.filename = checkpoint_info.filename
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
del state_dict
# clean up cache if limit is reached