mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-25 16:59:18 +00:00
@@ -45,7 +45,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
|
||||
model = IntegratedAutoencoderKL.from_config(config)
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
load_state_dict(model, state_dict, ignore_start='loss.')
|
||||
return model
|
||||
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
|
||||
from transformers import CLIPTextConfig, CLIPTextModel
|
||||
@@ -113,13 +113,16 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
return None
|
||||
|
||||
|
||||
def split_state_dict(sd):
|
||||
def split_state_dict(sd, sd_vae=None):
|
||||
guess = huggingface_guess.guess(sd)
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
|
||||
if sd_vae is not None:
|
||||
print(f'Using external VAE state dict: {len(sd_vae)}')
|
||||
|
||||
state_dict = {
|
||||
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
|
||||
}
|
||||
|
||||
sd = guess.process_clip_state_dict(sd)
|
||||
@@ -138,8 +141,8 @@ def split_state_dict(sd):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forge_loader(sd):
|
||||
state_dicts, estimated_config = split_state_dict(sd)
|
||||
def forge_loader(sd, sd_vae=None):
|
||||
state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
|
||||
repo_name = estimated_config.huggingface_repo
|
||||
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
|
||||
Reference in New Issue
Block a user