diff --git a/backend/loader.py b/backend/loader.py
index 4d928c38..606a33e6 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -45,7 +45,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
             model = IntegratedAutoencoderKL.from_config(config)
 
-        load_state_dict(model, state_dict)
+        load_state_dict(model, state_dict, ignore_start='loss.')
         return model
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
         from transformers import CLIPTextConfig, CLIPTextModel
@@ -113,13 +113,16 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     return None
 
 
-def split_state_dict(sd):
+def split_state_dict(sd, sd_vae=None):
     guess = huggingface_guess.guess(sd)
     guess.clip_target = guess.clip_target(sd)
 
+    if sd_vae is not None:
+        print(f'Using external VAE state dict: {len(sd_vae)}')
+
     state_dict = {
         guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
-        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
     }
 
     sd = guess.process_clip_state_dict(sd)
@@ -138,8 +141,8 @@ def split_state_dict(sd):
 
 
 @torch.no_grad()
-def forge_loader(sd):
-    state_dicts, estimated_config = split_state_dict(sd)
+def forge_loader(sd, sd_vae=None):
+    state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
     repo_name = estimated_config.huggingface_repo
 
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
diff --git a/backend/state_dict.py b/backend/state_dict.py
index ef3bdc8b..317bfce5 100644
--- a/backend/state_dict.py
+++ b/backend/state_dict.py
@@ -1,10 +1,15 @@
 import torch
 
 
-def load_state_dict(model, sd, ignore_errors=[], log_name=None):
+def load_state_dict(model, sd, ignore_errors=[], log_name=None, ignore_start=None):
     missing, unexpected = model.load_state_dict(sd, strict=False)
     missing = [x for x in missing if x not in ignore_errors]
     unexpected = [x for x in unexpected if x not in ignore_errors]
+
+    if isinstance(ignore_start, str):
+        missing = [x for x in missing if not x.startswith(ignore_start)]
+        unexpected = [x for x in unexpected if not x.startswith(ignore_start)]
+
     log_name = log_name or type(model).__name__
     if len(missing) > 0:
         print(f'{log_name} Missing: {missing}')
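Note on `ignore_start` above: SD-style VAE checkpoints typically carry training-only `loss.*` tensors (LPIPS/discriminator weights) that an inference-only autoencoder never defines, so every load used to report them as unexpected keys. The new parameter only filters the log report; loading still happens with `strict=False` exactly as before. A self-contained sketch of the filter, with hypothetical key names:

```python
# Hypothetical key names; only the reporting is filtered, not the loading itself.
missing = ['decoder.mid.attn_1.q.weight', 'loss.logvar']
unexpected = ['loss.discriminator.main.0.weight']

ignore_start = 'loss.'
missing = [x for x in missing if not x.startswith(ignore_start)]
unexpected = [x for x in unexpected if not x.startswith(ignore_start)]

print(missing)     # ['decoder.mid.attn_1.q.weight'] -- genuine problems still get logged
print(unexpected)  # [] -- training-only 'loss.*' keys are silenced
```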
diff --git a/modules/img2img.py b/modules/img2img.py
index f38cdc9c..fb5431ac 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -196,7 +196,6 @@ def img2img_function(id_task: str, request: gr.Request, mode: int, prompt: str,
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
 
     p = StableDiffusionProcessingImg2Img(
-        sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
         outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
         prompt=prompt,
diff --git a/modules/processing.py b/modules/processing.py
index 2a156d97..8b1c01a8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,7 +30,7 @@ import modules.sd_vae as sd_vae
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 
-from modules.sd_models import apply_token_merging
+from modules.sd_models import apply_token_merging, forge_model_reload
 from modules_forge.utils import apply_circular_forge
 
 
@@ -774,41 +774,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
+    forge_model_reload()
+
     if p.scripts is not None:
         p.scripts.before_process(p)
 
-    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
+    # backwards compatibility, fix sampler and scheduler if invalid
+    sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
 
-    try:
-        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
-        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
-        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
-            p.override_settings.pop('sd_model_checkpoint', None)
-            sd_models.reload_model_weights()
-
-        for k, v in p.override_settings.items():
-            opts.set(k, v, is_api=True, run_callbacks=False)
-
-            if k == 'sd_model_checkpoint':
-                sd_models.reload_model_weights()
-
-            if k == 'sd_vae':
-                sd_vae.reload_vae_weights()
-
-        # backwards compatibility, fix sampler and scheduler if invalid
-        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
-
-        with profiling.Profiler():
-            res = process_images_inner(p)
-
-    finally:
-        # restore opts to original state
-        if p.override_settings_restore_afterwards:
-            for k, v in stored_opts.items():
-                setattr(opts, k, v)
-
-                if k == 'sd_vae':
-                    sd_vae.reload_vae_weights()
+    with profiling.Profiler():
+        res = process_images_inner(p)
 
     return res
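`process_images` now calls `forge_model_reload()` unconditionally instead of juggling checkpoint/VAE swaps through `override_settings`. That is affordable because the reload (see the `sd_models.py` hunks below) short-circuits on a string hash of the requested loading parameters, which is also why `CheckpointInfo` gains deterministic `__str__`/`__repr__`: the default object repr embeds a memory address and would defeat the comparison. A condensed, runnable sketch of the idea:

```python
# Condensed sketch of the hash guard used by forge_model_reload() below.
class CheckpointInfo:
    def __init__(self, filename, hash):
        self.filename = filename
        self.hash = hash

    def __repr__(self):
        # deterministic: same filename/hash -> same string, no memory address
        return str(dict(filename=self.filename, hash=self.hash))

a = dict(checkpoint_info=CheckpointInfo('model.safetensors', 'abc123'), vae_filename=None)
b = dict(checkpoint_info=CheckpointInfo('model.safetensors', 'abc123'), vae_filename=None)
print(str(a) == str(b))  # True -> forge_hash matches -> the reload is skipped
```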
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1204b1e6..b3d08c85 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -132,6 +132,12 @@ class CheckpointInfo:
 
         return self.shorthash
 
+    def __str__(self):
+        return str(dict(filename=self.filename, hash=self.hash))
+
+    def __repr__(self):
+        return str(dict(filename=self.filename, hash=self.hash))
+
 
 # try:
 # #     this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -379,8 +385,8 @@ def apply_alpha_schedule_override(sd_model, p=None):
 class SdModelData:
     def __init__(self):
         self.sd_model = None
-        self.loaded_sd_models = []
-        self.was_loaded_at_least_once = False
+        self.forge_loading_parameters = {}
+        self.forge_hash = ''
 
     def get_sd_model(self):
         if self.sd_model is None:
@@ -388,12 +394,8 @@ class SdModelData:
 
         return self.sd_model
 
-    def set_sd_model(self, v, already_loaded=False):
+    def set_sd_model(self, v):
         self.sd_model = v
-        if already_loaded:
-            sd_vae.base_vae = getattr(v, "base_vae", None)
-            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
-            sd_vae.checkpoint_info = v.sd_checkpoint_info
 
 
 model_data = SdModelData()
@@ -461,28 +463,45 @@ def apply_token_merging(sd_model, token_merging_ratio):
 
 @torch.no_grad()
 def forge_model_reload():
-    checkpoint_info = select_checkpoint()
+    current_hash = str(model_data.forge_loading_parameters)
+
+    if model_data.forge_hash == current_hash:
+        return model_data.sd_model
+
+    print('Loading Model: ' + str(model_data.forge_loading_parameters))
 
     timer = Timer()
 
     if model_data.sd_model:
         model_data.sd_model = None
-        model_data.loaded_sd_models = []
         memory_management.unload_all_models()
         memory_management.soft_empty_cache()
         gc.collect()
 
     timer.record("unload existing model")
 
-    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
+    state_dict = load_torch_file(checkpoint_info.filename)
+    timer.record("load state dict")
+
+    state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
+
+    if state_dict_vae is not None:
+        state_dict_vae = load_torch_file(state_dict_vae)
+
+    timer.record("load vae state dict")
 
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
 
+    timer.record("cache state dict")
+
+    dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
     dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
     dynamic_args['emphasis_name'] = opts.emphasis
-    sd_model = forge_loader(state_dict)
+
+    sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
+    del state_dict
     timer.record("forge model load")
 
     sd_model.extra_generation_params = {}
@@ -492,22 +511,13 @@ def forge_model_reload():
     sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
-    del state_dict
-
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
         checkpoints_loaded.popitem(last=False)
 
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
-    sd_vae.delete_base_vae()
-    sd_vae.clear_loaded_vae()
-    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
-    sd_vae.load_vae(sd_model, vae_file, vae_source)
-    timer.record("load VAE")
-
     model_data.set_sd_model(sd_model)
-    model_data.was_loaded_at_least_once = True
 
     script_callbacks.model_loaded_callback(sd_model)
 
@@ -515,4 +525,6 @@ def forge_model_reload():
 
     print(f"Model loaded in {timer.summary()}.")
 
+    model_data.forge_hash = current_hash
+
     return sd_model
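The `vae_filename` recorded in `forge_loading_parameters` is loaded with `load_torch_file` and handed to `forge_loader(state_dict, sd_vae=...)`, so an external VAE now displaces the checkpoint's baked-in first-stage weights at split time rather than being patched onto a live model afterwards. A toy sketch of that substitution (hypothetical helper; the real `try_filter_state_dict` also handles prefix lists and key renaming):

```python
# Toy stand-in for the sd_vae branch in split_state_dict(); single string prefix only.
def pick_vae_state_dict(sd, vae_key_prefix, sd_vae=None):
    if sd_vae is not None:
        return sd_vae  # external VAE file wins outright
    return {k: v for k, v in sd.items() if k.startswith(vae_key_prefix)}

checkpoint_sd = {'first_stage_model.decoder.w': 1.0, 'model.diffusion_model.w': 2.0}
external_vae = {'decoder.w': 3.0}

print(pick_vae_state_dict(checkpoint_sd, 'first_stage_model.'))                # baked-in VAE
print(pick_vae_state_dict(checkpoint_sd, 'first_stage_model.', external_vae))  # external VAE
```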
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 05c1cb1b..ef0ef168 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -187,87 +187,24 @@ def resolve_vae(checkpoint_file) -> VaeResolution:
 
 
 def load_vae_dict(filename, map_location):
-    return load_torch_file(filename)
+    pass
 
 
 def load_vae(model, vae_file=None, vae_source="from unknown source"):
-    global vae_dict, base_vae, loaded_vae_file
-    # save_settings = False
-
-    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
-
-    if vae_file:
-        if cache_enabled and vae_file in checkpoints_loaded:
-            # use vae checkpoint cache
-            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
-            store_base_vae(model)
-            _load_vae_dict(model, checkpoints_loaded[vae_file])
-        else:
-            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
-            print(f"Loading VAE weights {vae_source}: {vae_file}")
-            store_base_vae(model)
-
-            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
-            _load_vae_dict(model, vae_dict_1)
-
-            if cache_enabled:
-                # cache newly loaded vae
-                checkpoints_loaded[vae_file] = vae_dict_1.copy()
-
-        # clean up cache if limit is reached
-        if cache_enabled:
-            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
-                checkpoints_loaded.popitem(last=False)  # LRU
-
-        # If vae used is not in dict, update it
-        # It will be removed on refresh though
-        vae_opt = get_filename(vae_file)
-        if vae_opt not in vae_dict:
-            vae_dict[vae_opt] = vae_file
-
-    elif loaded_vae_file:
-        restore_base_vae(model)
-
-    loaded_vae_file = vae_file
-    model.base_vae = base_vae
-    model.loaded_vae_file = loaded_vae_file
+    raise NotImplementedError('Forge does not use this!')
 
 
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
-    model.first_stage_model.load_state_dict(vae_dict_1, strict=False)
+    pass
 
 
 def clear_loaded_vae():
-    global loaded_vae_file
-    loaded_vae_file = None
+    pass
 
 
 unspecified = object()
 
 
 def reload_vae_weights(sd_model=None, vae_file=unspecified):
-    if not sd_model:
-        sd_model = shared.sd_model
-
-    checkpoint_info = sd_model.sd_checkpoint_info
-    checkpoint_file = checkpoint_info.filename
-
-    if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
-    else:
-        vae_source = "from function argument"
-
-    if loaded_vae_file == vae_file:
-        return
-
-    # sd_hijack.model_hijack.undo_hijack(sd_model)
-
-    load_vae(sd_model, vae_file, vae_source)
-
-    # sd_hijack.model_hijack.hijack(sd_model)
-
-    script_callbacks.model_loaded_callback(sd_model)
-
-    print("VAE weights loaded.")
-    return sd_model
+    raise NotImplementedError('Forge does not use this!')
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 3cc6599f..3ca0960a 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -19,7 +19,6 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
         enable_hr = True
 
     p = processing.StableDiffusionProcessingTxt2Img(
-        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
         outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
         prompt=prompt,
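With `load_vae` and `reload_vae_weights` now raising, code that used to hot-swap VAE weights on the live model has no direct equivalent; the assumed migration path is to update the stored option and re-record the loading parameters, deferring the actual swap to the next `process_images()` call. A hypothetical sketch (the VAE name is illustrative; `refresh_model_loading_parameters` comes from the `main_entry.py` changes below):

```python
from modules import shared
from modules_forge.main_entry import refresh_model_loading_parameters

# Before (now raises NotImplementedError in Forge):
#   sd_vae.reload_vae_weights(sd_model, vae_file='ae.safetensors')

# After: record the desired VAE; forge_model_reload() applies it at the next generation.
shared.opts.set('sd_vae', 'ae.safetensors')  # hypothetical VAE entry
refresh_model_loading_parameters()
```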
diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py
index 5ff1fe8f..4de52909 100644
--- a/modules_forge/main_entry.py
+++ b/modules_forge/main_entry.py
@@ -3,7 +3,6 @@ import gradio as gr
 
 from modules import shared_items, shared, ui_common, sd_models
 from modules import sd_vae as sd_vae_module
-from modules_forge import main_thread
 from backend import args as backend_args
 
 
@@ -59,7 +58,7 @@ def make_checkpoint_manager_ui():
     ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
 
     ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
-    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
+    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
 
     ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
     bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
@@ -67,12 +66,18 @@ def make_checkpoint_manager_ui():
     return
 
 
-def model_load_entry():
-    backend_args.dynamic_args.update(dict(
-        forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
-    ))
+def refresh_model_loading_parameters():
+    from modules.sd_models import select_checkpoint, model_data
+
+    checkpoint_info = select_checkpoint()
+    vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
+
+    model_data.forge_loading_parameters = dict(
+        checkpoint_info=checkpoint_info,
+        vae_filename=vae_resolution.vae,
+        unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
+    )
 
-    sd_models.forge_model_reload()
     return
 
 
@@ -81,21 +86,22 @@ def checkpoint_change(ckpt_name):
     shared.opts.set('sd_model_checkpoint', ckpt_name)
     shared.opts.save(shared.config_filename)
 
-    model_load_entry()
+    refresh_model_loading_parameters()
     return
 
 
 def vae_change(vae_name):
     print(f'VAE Selected: {vae_name}')
     shared.opts.set('sd_vae', vae_name)
-    sd_vae_module.reload_vae_weights()
+
+    refresh_model_loading_parameters()
     return
 
 
 def forge_main_entry():
-    ui_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[ui_checkpoint], show_progress=False)
-    ui_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[ui_vae], show_progress=False)
+    ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False)
+    ui_vae.change(vae_change, inputs=[ui_vae], show_progress=False)
 
     # Load Model
-    main_thread.async_run(model_load_entry)
+    refresh_model_loading_parameters()
 
     return
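Taken together, the UI handlers become pure bookkeeping: switching checkpoints or VAEs in the dropdowns just rewrites `model_data.forge_loading_parameters`, and the only place that touches the disk is `forge_model_reload()` at the start of generation. A self-contained simulation of that record-now, load-later behaviour (stub names; not the real API):

```python
# Stub simulation of the deferred-loading flow; names are illustrative.
loading_parameters = {}
loaded_hash = ''

def refresh(**params):            # plays the role of refresh_model_loading_parameters()
    loading_parameters.update(params)

def generate():                   # plays the role of process_images() -> forge_model_reload()
    global loaded_hash
    if loaded_hash != str(loading_parameters):
        print('loading model:', loading_parameters)   # expensive work happens only here
        loaded_hash = str(loading_parameters)

refresh(checkpoint='sd_xl_base.safetensors')  # UI event: no I/O
refresh(vae='sdxl_vae.safetensors')           # UI event: no I/O
generate()                                    # loads once
generate()                                    # no-op: parameters unchanged
```

This is also why the `main_thread.async_run` wrappers could be dropped from the UI callbacks: recording a dict is cheap enough to run directly on the Gradio event thread.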