diff --git a/modules/sd_models.py b/modules/sd_models.py index 636bc518..6774918a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -599,92 +599,26 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = state_dict.copy() + sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict) - checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) - - timer.record("find config") - - sd_config = OmegaConf.load(checkpoint_config) - repair_config(sd_config) - - timer.record("load config") - - if hasattr(sd_config.model.params, 'network_config'): - sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject' - - if hasattr(sd_config.model.params, 'unet_config'): - sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject' - - if hasattr(sd_config.model.params, 'first_stage_config'): - sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject' - - print(f"Creating model from config: {checkpoint_config}") - - sd_model = None - try: - with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - with forge_ops.use_patched_ops(manual_cast): - sd_model = instantiate_from_config(sd_config.model) - - except Exception as e: - errors.display(e, "creating model quickly", full_traceback=True) - - if sd_model is None: - print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - - with forge_ops.use_patched_ops(manual_cast): - sd_model = instantiate_from_config(sd_config.model) - - sd_model.used_config = checkpoint_config - - timer.record("create model") - - state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')} - state_dict_for_forge = {k: v for k, v in state_dict.items()} del state_dict - unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge) - sd_model.first_stage_model = vae_patcher.first_stage_model - sd_model.model.diffusion_model = unet_patcher.model.diffusion_model - sd_model.unet_patcher = unet_patcher - sd_model.model.diffusion_model.patcher = unet_patcher - sd_model.vae_patcher = vae_patcher - sd_model.first_stage_model.patcher = vae_patcher - timer.record("create unet patcher") - del state_dict_for_forge + # clean up cache if limit is reached + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) - def patched_decode_first_stage(sample): - sample = unet_patcher.model.model_config.latent_format.process_out(sample) - return vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0 + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - sd_model.decode_first_stage = patched_decode_first_stage + sd_vae.delete_base_vae() + sd_vae.clear_loaded_vae() + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() + sd_vae.load_vae(sd_model, vae_file, vae_source) + timer.record("load VAE") - load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer) - del state_dict_for_a1111 - timer.record("load weights from state dict") - - current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model - clip_load_device = model_management.text_encoder_device() - clip_offload_device = model_management.text_encoder_offload_device() - clip_dtype = model_management.text_encoder_dtype() - - current_clip.to(clip_dtype) - clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher( - current_clip, - load_device=clip_load_device, - offload_device=clip_offload_device - ) - sd_model.clip_patcher = clip_patcher - current_clip.patcher = clip_patcher - timer.record("create clip patcher") - - sd_hijack.model_hijack.hijack(sd_model) - - timer.record("hijack") - - sd_model.eval() model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True @@ -696,7 +630,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") - with devices.autocast(), torch.no_grad(): + with torch.no_grad(): sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt")