Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-01-26 19:09:45 +00:00
@@ -45,7 +45,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
             model = IntegratedAutoencoderKL.from_config(config)
 
-        load_state_dict(model, state_dict)
+        load_state_dict(model, state_dict, ignore_start='loss.')
         return model
 
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
         from transformers import CLIPTextConfig, CLIPTextModel
@@ -113,13 +113,16 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     return None
 
 
-def split_state_dict(sd):
+def split_state_dict(sd, sd_vae=None):
     guess = huggingface_guess.guess(sd)
     guess.clip_target = guess.clip_target(sd)
 
+    if sd_vae is not None:
+        print(f'Using external VAE state dict: {len(sd_vae)}')
+
     state_dict = {
         guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
-        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
     }
 
     sd = guess.process_clip_state_dict(sd)
@@ -138,8 +141,8 @@ def split_state_dict(sd):
 
 
 @torch.no_grad()
-def forge_loader(sd):
-    state_dicts, estimated_config = split_state_dict(sd)
+def forge_loader(sd, sd_vae=None):
+    state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
     repo_name = estimated_config.huggingface_repo
 
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
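
Together, these two loader hunks thread an optional external VAE through `forge_loader` into `split_state_dict`: when `sd_vae` is given, it replaces the VAE weights that would otherwise be filtered out of the main checkpoint. A minimal sketch of the substitution, using toy dicts and an illustrative stand-in for `try_filter_state_dict` (only the names taken from the diff are real):

    def try_filter_state_dict(sd, prefixes):
        # Toy stand-in: keep the keys matching any prefix. The real helper
        # also strips the prefix and removes the keys from sd.
        return {k: v for k, v in sd.items() if any(k.startswith(p) for p in prefixes)}

    checkpoint_sd = {'model.diffusion_model.w': 1, 'first_stage_model.decoder.w': 2}
    external_vae_sd = {'decoder.w': 3}  # e.g. a standalone .safetensors/.pt VAE

    for sd_vae in (None, external_vae_sd):
        vae_sd = try_filter_state_dict(checkpoint_sd, ['first_stage_model.']) if sd_vae is None else sd_vae
        print(vae_sd)
    # {'first_stage_model.decoder.w': 2}   <- the checkpoint's own VAE
    # {'decoder.w': 3}                     <- the external VAE wins when provided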
@@ -1,10 +1,15 @@
 import torch
 
 
-def load_state_dict(model, sd, ignore_errors=[], log_name=None):
+def load_state_dict(model, sd, ignore_errors=[], log_name=None, ignore_start=None):
     missing, unexpected = model.load_state_dict(sd, strict=False)
     missing = [x for x in missing if x not in ignore_errors]
     unexpected = [x for x in unexpected if x not in ignore_errors]
 
+    if isinstance(ignore_start, str):
+        missing = [x for x in missing if not x.startswith(ignore_start)]
+        unexpected = [x for x in unexpected if not x.startswith(ignore_start)]
+
     log_name = log_name or type(model).__name__
     if len(missing) > 0:
         print(f'{log_name} Missing: {missing}')
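
A quick self-contained check of the new `ignore_start` behaviour. The toy module below is illustrative; the point is that keys under a prefix such as `loss.` (a VAE's training-time discriminator head, which the `ignore_start='loss.'` call site above targets) are filtered out of the missing/unexpected reports instead of producing log noise:

    import torch

    class ToyVAE(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder = torch.nn.Linear(4, 4)
            self.loss = torch.nn.Linear(4, 1)  # training-only head, absent from inference checkpoints

    model = ToyVAE()
    sd = {'decoder.weight': torch.zeros(4, 4), 'decoder.bias': torch.zeros(4)}

    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(missing)                                            # ['loss.weight', 'loss.bias']
    print([x for x in missing if not x.startswith('loss.')])  # [] -- what gets logged now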
@@ -196,7 +196,6 @@ def img2img_function(id_task: str, request: gr.Request, mode: int, prompt: str,
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
 
-
     p = StableDiffusionProcessingImg2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
         outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
         prompt=prompt,
@@ -30,7 +30,7 @@ import modules.sd_vae as sd_vae
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
-from modules.sd_models import apply_token_merging
+from modules.sd_models import apply_token_merging, forge_model_reload
 from modules_forge.utils import apply_circular_forge
@@ -774,41 +774,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
+    forge_model_reload()
+
     if p.scripts is not None:
         p.scripts.before_process(p)
 
     stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
+    # backwards compatibility, fix sampler and scheduler if invalid
+    sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
 
-    try:
-        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
-        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
-        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
-            p.override_settings.pop('sd_model_checkpoint', None)
-            sd_models.reload_model_weights()
-
-        for k, v in p.override_settings.items():
-            opts.set(k, v, is_api=True, run_callbacks=False)
-
-            if k == 'sd_model_checkpoint':
-                sd_models.reload_model_weights()
-
-            if k == 'sd_vae':
-                sd_vae.reload_vae_weights()
-
-        # backwards compatibility, fix sampler and scheduler if invalid
-        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
-
-        with profiling.Profiler():
-            res = process_images_inner(p)
-
-    finally:
-        # restore opts to original state
-        if p.override_settings_restore_afterwards:
-            for k, v in stored_opts.items():
-                setattr(opts, k, v)
-
-                if k == 'sd_vae':
-                    sd_vae.reload_vae_weights()
+    with profiling.Profiler():
+        res = process_images_inner(p)
 
     return res
@@ -132,6 +132,12 @@ class CheckpointInfo:
 
         return self.shorthash
 
+    def __str__(self):
+        return str(dict(filename=self.filename, hash=self.hash))
+
+    def __repr__(self):
+        return str(dict(filename=self.filename, hash=self.hash))
+
 
 # try:
 #     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
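
Beyond nicer logs, this repr is load-bearing: `forge_model_reload()` below caches on `current_hash = str(model_data.forge_loading_parameters)`, and that string embeds the repr of the contained `CheckpointInfo`. A value-based repr makes the hash change exactly when the selected file or its hash changes; a toy illustration (the class name here is hypothetical):

    class Info:
        def __init__(self, filename, hash):
            self.filename, self.hash = filename, hash

        def __repr__(self):
            return str(dict(filename=self.filename, hash=self.hash))

    params = dict(checkpoint_info=Info('sd15.safetensors', 'abc123'), unet_storage_dtype=None)
    print(str(params))
    # {'checkpoint_info': {'filename': 'sd15.safetensors', 'hash': 'abc123'}, 'unet_storage_dtype': None}
    # With the default object repr, this string would embed an id()-based address
    # that differs between objects and runs, making it useless as a cache key.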
@@ -379,8 +385,8 @@ def apply_alpha_schedule_override(sd_model, p=None):
 class SdModelData:
     def __init__(self):
         self.sd_model = None
         self.loaded_sd_models = []
         self.was_loaded_at_least_once = False
+        self.forge_loading_parameters = {}
+        self.forge_hash = ''
 
     def get_sd_model(self):
         if self.sd_model is None:
@@ -388,12 +394,8 @@
 
         return self.sd_model
 
-    def set_sd_model(self, v, already_loaded=False):
+    def set_sd_model(self, v):
         self.sd_model = v
-        if already_loaded:
-            sd_vae.base_vae = getattr(v, "base_vae", None)
-            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
-            sd_vae.checkpoint_info = v.sd_checkpoint_info
 
 
 model_data = SdModelData()
@@ -461,28 +463,45 @@ def apply_token_merging(sd_model, token_merging_ratio):
 
 @torch.no_grad()
 def forge_model_reload():
-    checkpoint_info = select_checkpoint()
+    current_hash = str(model_data.forge_loading_parameters)
+
+    if model_data.forge_hash == current_hash:
+        return model_data.sd_model
+
+    print('Loading Model: ' + str(model_data.forge_loading_parameters))
 
     timer = Timer()
 
     if model_data.sd_model:
         model_data.sd_model = None
         model_data.loaded_sd_models = []
         memory_management.unload_all_models()
         memory_management.soft_empty_cache()
         gc.collect()
 
     timer.record("unload existing model")
 
-    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
+    state_dict = load_torch_file(checkpoint_info.filename)
     timer.record("load state dict")
 
+    state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
+
+    if state_dict_vae is not None:
+        state_dict_vae = load_torch_file(state_dict_vae)
+
+    timer.record("load vae state dict")
+
-    if shared.opts.sd_checkpoint_cache > 0:
-        # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = state_dict.copy()
-
-    timer.record("cache state dict")
 
     dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
     dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
     dynamic_args['emphasis_name'] = opts.emphasis
-    sd_model = forge_loader(state_dict)
+    sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
     del state_dict
     timer.record("forge model load")
 
     sd_model.extra_generation_params = {}
@@ -492,22 +511,13 @@ def forge_model_reload():
     sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
-    del state_dict
-
-    # clean up cache if limit is reached
-    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-        checkpoints_loaded.popitem(last=False)
-
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
-    sd_vae.delete_base_vae()
-    sd_vae.clear_loaded_vae()
-    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
-    sd_vae.load_vae(sd_model, vae_file, vae_source)
-    timer.record("load VAE")
-
     model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
 
     script_callbacks.model_loaded_callback(sd_model)
@@ -515,4 +525,6 @@ def forge_model_reload():
 
     print(f"Model loaded in {timer.summary()}.")
 
+    model_data.forge_hash = current_hash
+
     return sd_model
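
The net effect of the `sd_models.py` changes is a declarative reload: callers describe what should be loaded in `model_data.forge_loading_parameters`, and `forge_model_reload()` compares the stringified parameters against `forge_hash` to decide whether any work is needed. Note that `forge_hash` is only written after a successful load, so a failed load is retried on the next call. A minimal self-contained sketch of the pattern (the names here are illustrative, not from the diff):

    class ModelSlot:
        def __init__(self):
            self.model = None
            self.loading_parameters = {}  # desired state, written by the UI
            self.loaded_hash = ''         # stringified parameters of what is actually loaded

        def reload(self, load_fn):
            wanted = str(self.loading_parameters)
            if self.loaded_hash == wanted:
                return self.model         # parameters unchanged: skip the expensive load
            self.model = load_fn(dict(self.loading_parameters))
            self.loaded_hash = wanted     # mark as loaded only after success
            return self.model

    slot = ModelSlot()
    slot.loading_parameters = dict(checkpoint='sd15.safetensors', vae_filename=None)
    m1 = slot.reload(lambda p: f"model({p})")   # performs the load
    m2 = slot.reload(lambda p: f"model({p})")   # cached; returns the same object
    assert m1 is m2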
@@ -187,87 +187,24 @@ def resolve_vae(checkpoint_file) -> VaeResolution:
 
 
 def load_vae_dict(filename, map_location):
-    return load_torch_file(filename)
+    pass
 
 
 def load_vae(model, vae_file=None, vae_source="from unknown source"):
-    global vae_dict, base_vae, loaded_vae_file
-    # save_settings = False
-
-    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
-
-    if vae_file:
-        if cache_enabled and vae_file in checkpoints_loaded:
-            # use vae checkpoint cache
-            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
-            store_base_vae(model)
-            _load_vae_dict(model, checkpoints_loaded[vae_file])
-        else:
-            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
-            print(f"Loading VAE weights {vae_source}: {vae_file}")
-            store_base_vae(model)
-
-            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
-            _load_vae_dict(model, vae_dict_1)
-
-            if cache_enabled:
-                # cache newly loaded vae
-                checkpoints_loaded[vae_file] = vae_dict_1.copy()
-
-        # clean up cache if limit is reached
-        if cache_enabled:
-            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
-                checkpoints_loaded.popitem(last=False) # LRU
-
-        # If vae used is not in dict, update it
-        # It will be removed on refresh though
-        vae_opt = get_filename(vae_file)
-        if vae_opt not in vae_dict:
-            vae_dict[vae_opt] = vae_file
-
-    elif loaded_vae_file:
-        restore_base_vae(model)
-
-    loaded_vae_file = vae_file
-    model.base_vae = base_vae
-    model.loaded_vae_file = loaded_vae_file
+    raise NotImplementedError('Forge does not use this!')
 
 
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
-    model.first_stage_model.load_state_dict(vae_dict_1, strict=False)
+    pass
 
 
 def clear_loaded_vae():
     global loaded_vae_file
     loaded_vae_file = None
+    pass
 
 
 unspecified = object()
 
 
 def reload_vae_weights(sd_model=None, vae_file=unspecified):
-    if not sd_model:
-        sd_model = shared.sd_model
-
-    checkpoint_info = sd_model.sd_checkpoint_info
-    checkpoint_file = checkpoint_info.filename
-
-    if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
-    else:
-        vae_source = "from function argument"
-
-    if loaded_vae_file == vae_file:
-        return
-
-    # sd_hijack.model_hijack.undo_hijack(sd_model)
-
-    load_vae(sd_model, vae_file, vae_source)
-
-    # sd_hijack.model_hijack.hijack(sd_model)
-
-    script_callbacks.model_loaded_callback(sd_model)
-
-    print("VAE weights loaded.")
-    return sd_model
+    raise NotImplementedError('Forge does not use this!')
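
Stubbing these paths with a loud `NotImplementedError` rather than a silent `pass` means extensions that still drive VAE swaps through `load_vae()` or `reload_vae_weights()` fail immediately instead of silently diverging from the backend's state. A hypothetical extension-side accommodation, assuming the usual `sd_vae`/`shared` imports and an in-scope `model`, `vae_file`, and `vae_name` (the `try/except` is illustrative; `shared.opts.set('sd_vae', ...)` plus a parameter refresh is the route the Forge UI itself takes below):

    try:
        sd_vae.load_vae(model, vae_file, 'from extension')
    except NotImplementedError:
        # On Forge, record the desired VAE in options instead; the loader
        # applies it on the next forge_model_reload().
        shared.opts.set('sd_vae', vae_name)
        refresh_model_loading_parameters()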
@@ -19,7 +19,6 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
     enable_hr = True
 
-
     p = processing.StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
         outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
@@ -3,7 +3,6 @@ import gradio as gr
 
 from modules import shared_items, shared, ui_common, sd_models
 from modules import sd_vae as sd_vae_module
 from modules_forge import main_thread
-from backend import args as backend_args
@@ -59,7 +58,7 @@ def make_checkpoint_manager_ui():
     ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
 
     ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
-    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
+    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
 
     ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
     bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
@@ -67,12 +66,18 @@ def make_checkpoint_manager_ui():
     return
 
 
-def model_load_entry():
-    backend_args.dynamic_args.update(dict(
-        forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
-    ))
+def refresh_model_loading_parameters():
+    from modules.sd_models import select_checkpoint, model_data
+
+    checkpoint_info = select_checkpoint()
+    vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
+
+    model_data.forge_loading_parameters = dict(
+        checkpoint_info=checkpoint_info,
+        vae_filename=vae_resolution.vae,
+        unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
+    )
 
-    sd_models.forge_model_reload()
     return
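
The replacement inverts the control flow: `model_load_entry()` loaded weights eagerly on every UI event, while `refresh_model_loading_parameters()` only records the desired checkpoint/VAE/dtype and leaves the load to `forge_model_reload()` at the start of the next generation. Rapid UI changes therefore cost nothing until an image is actually requested, which is also why the next hunk can drop the `main_thread.async_run` wrappers. A usage sketch, reusing the illustrative `ModelSlot` class from the `sd_models.py` notes above:

    slot = ModelSlot()

    def on_vae_change(vae_name):
        # Cheap: record intent only; no weights are touched here.
        slot.loading_parameters = dict(checkpoint='current.safetensors', vae_filename=vae_name)

    def generate():
        # Lazy: at most one load, right before it is needed.
        return slot.reload(lambda p: f"model({p})")

    on_vae_change('anime.vae.pt')
    on_vae_change('photo.vae.pt')  # overwrites intent; 'anime' is never loaded
    print(generate())              # single load with the final parameters
    print(generate())              # cached; no reload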
@@ -81,21 +86,22 @@ def checkpoint_change(ckpt_name):
     shared.opts.set('sd_model_checkpoint', ckpt_name)
     shared.opts.save(shared.config_filename)
 
-    model_load_entry()
+    refresh_model_loading_parameters()
     return
 
 
 def vae_change(vae_name):
     print(f'VAE Selected: {vae_name}')
     shared.opts.set('sd_vae', vae_name)
-    sd_vae_module.reload_vae_weights()
+
+    refresh_model_loading_parameters()
     return
 
 
 def forge_main_entry():
-    ui_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[ui_checkpoint], show_progress=False)
-    ui_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[ui_vae], show_progress=False)
+    ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False)
+    ui_vae.change(vae_change, inputs=[ui_vae], show_progress=False)
 
-    # Load Model
-    main_thread.async_run(model_load_entry)
+    refresh_model_loading_parameters()
     return