Load Model only when click Generate

#964
This commit is contained in:
lllyasviel
2024-08-08 14:51:13 -07:00
committed by GitHub
parent ce3f0f86b4
commit 6921420b3f
8 changed files with 76 additions and 140 deletions

View File

@@ -45,7 +45,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
model = IntegratedAutoencoderKL.from_config(config)
load_state_dict(model, state_dict)
load_state_dict(model, state_dict, ignore_start='loss.')
return model
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
from transformers import CLIPTextConfig, CLIPTextModel
@@ -113,13 +113,16 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return None
def split_state_dict(sd):
def split_state_dict(sd, sd_vae=None):
guess = huggingface_guess.guess(sd)
guess.clip_target = guess.clip_target(sd)
if sd_vae is not None:
print(f'Using external VAE state dict: {len(sd_vae)}')
state_dict = {
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
}
sd = guess.process_clip_state_dict(sd)
@@ -138,8 +141,8 @@ def split_state_dict(sd):
@torch.no_grad()
def forge_loader(sd):
state_dicts, estimated_config = split_state_dict(sd)
def forge_loader(sd, sd_vae=None):
state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)