Files
stable-diffusion-webui-forge/backend/loader.py
layerdiffusion 24cfce26dc add sd2 template
2024-08-05 11:49:45 -07:00

119 lines
4.2 KiB
Python

import os
import torch
import logging
import importlib
import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL
# Engine classes that forge_loader tries, in this order, when matching the
# guessed model configuration against each class's `matched_guesses`.
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL]
# Raise the "diffusers" logger threshold so only ERROR and above are emitted.
logging.getLogger("diffusers").setLevel(logging.ERROR)
# Directory containing this module; forge_loader resolves the local
# 'huggingface/<repo>' template directories relative to it.
dir_path = os.path.dirname(__file__)
def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
    """Instantiate one component of a pipeline template and load its weights.

    Args:
        guess: Model description object. Only ``guess.unet_config`` is read
            here, and only when building the UNet.
        component_name: Name of the component in the pipeline config (for
            example ``'scheduler'`` or a name starting with ``'tokenizer'``
            or ``'text_encoder'``). It is also the sub-directory of
            ``repo_path`` that holds the component's own config.
        lib_name: Library named by the pipeline config for this component.
            Only ``'transformers'`` and ``'diffusers'`` are handled.
        cls_name: Class name given by the pipeline config for this component.
        repo_path: Directory of the local pipeline template.
        state_dict: Weights for this component, handed to ``load_state_dict``
            unchanged. Not consulted for schedulers and tokenizers.

    Returns:
        The instantiated component, or ``None`` when the component is
        deliberately ignored (feature extractor, safety checker) or is not
        supported (in which case a "Skipped" line is printed).
    """
    config_path = os.path.join(repo_path, component_name)

    # These pipeline entries are never instantiated by this loader.
    if component_name in ['feature_extractor', 'safety_checker']:
        return None

    if lib_name in ['transformers', 'diffusers']:
        if component_name in ['scheduler'] or component_name.startswith('tokenizer'):
            # Built by the upstream class directly from the template
            # directory; ``state_dict`` plays no part here.
            cls = getattr(importlib.import_module(lib_name), cls_name)
            return cls.from_pretrained(config_path)

        if cls_name in ['AutoencoderKL']:
            config = IntegratedAutoencoderKL.load_config(config_path)

            with using_forge_operations():
                model = IntegratedAutoencoderKL.from_config(config)

            load_state_dict(model, state_dict)
            return model

        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
            config = CLIPTextConfig.from_pretrained(config_path)

            # Same nesting order as before: no_init_weights is entered first.
            with modeling_utils.no_init_weights(), using_forge_operations():
                model = IntegratedCLIP(config)

            # Problems with the keys listed below are passed to
            # load_state_dict as ``ignore_errors``.
            load_state_dict(model, state_dict, ignore_errors=[
                'transformer.text_projection.weight',
                'transformer.text_model.embeddings.position_ids',
                'logit_scale'
            ], log_name=cls_name)
            return model

        if cls_name == 'UNet2DConditionModel':
            # The UNet config comes from the guess, not from the template
            # directory on disk.
            with using_forge_operations():
                model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
                model._internal_dict = guess.unet_config

            load_state_dict(model, state_dict)
            return model

    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None
def split_state_dict(sd):
    """Split a checkpoint state dict into one state dict per component.

    Args:
        sd: Checkpoint state dict. It is handed to ``huggingface_guess.guess``
            to identify the model and to ``try_filter_state_dict`` to extract
            each component's entries.

    Returns:
        A ``(state_dict, guess)`` tuple. ``state_dict`` maps ``'unet'``,
        ``'vae'`` and every component name listed in the guess's clip target
        mapping to that component's filtered state dict. ``guess`` is the
        object returned by ``huggingface_guess.guess``, with its
        ``clip_target`` attribute replaced by the resolved mapping.
    """
    guess = huggingface_guess.guess(sd)

    state_dict = {
        'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
        'vae': try_filter_state_dict(sd, guess.vae_key_prefix),
    }

    sd = guess.process_clip_state_dict(sd)
    # NOTE: this overwrites the callable ``clip_target`` attribute with the
    # mapping it returns, so afterwards ``guess.clip_target`` is data, not a
    # method.
    guess.clip_target = guess.clip_target(sd)

    for prefix, component_name in guess.clip_target.items():
        state_dict[component_name] = try_filter_state_dict(sd, [prefix + '.'])

    # Report the size of whatever is left in ``sd`` as 'ignore' (presumably
    # the keys no component claimed -- confirm against try_filter_state_dict).
    # The count goes into the log-only dict, so the returned mapping is never
    # touched and a component named 'ignore' cannot be lost.
    key_counts = {name: len(part) for name, part in state_dict.items()}
    key_counts['ignore'] = len(sd)
    print(f'StateDict Keys: {key_counts}')

    return state_dict, guess
@torch.no_grad()
def forge_loader(sd):
    """Build a diffusion engine from a checkpoint state dict.

    The checkpoint is split per component, the pipeline template matching the
    guessed model is read from the local ``huggingface/<repo>`` directory next
    to this module, each component of the template is instantiated, and the
    first engine class in ``possible_models`` that recognises the guess is
    constructed from the result.

    Args:
        sd: Checkpoint state dict, passed to ``split_state_dict``.

    Returns:
        An instance of the matching engine class, or ``None`` (after printing
        a message) when no class in ``possible_models`` matches the guess.
    """
    state_dicts, estimated_config = split_state_dict(sd)
    repo_name = estimated_config.huggingface_repo

    local_path = os.path.join(dir_path, 'huggingface', repo_name)
    config: dict = DiffusionPipeline.load_config(local_path)

    huggingface_components = {}
    for component_name, spec in config.items():
        # Only two-element list entries are treated as
        # ``[library, class]`` component specs; other entries are skipped.
        if not (isinstance(spec, list) and len(spec) == 2):
            continue

        lib_name, cls_name = spec
        # Take the component's weights out of the shared mapping so it no
        # longer holds a reference to them once the component is built.
        component_sd = state_dicts.pop(component_name, None)
        component = load_huggingface_component(
            estimated_config, component_name, lib_name, cls_name, local_path, component_sd
        )
        if component is not None:
            huggingface_components[component_name] = component

    # The first engine whose ``matched_guesses`` contains the guess's type wins.
    for engine_cls in possible_models:
        if any(isinstance(estimated_config, guess_cls) for guess_cls in engine_cls.matched_guesses):
            return engine_cls(estimated_config=estimated_config, huggingface_components=huggingface_components)

    print('Failed to recognize model type!')
    return None