# File: stable-diffusion-webui-forge/backend/loader.py
# Snapshot: 2024-08-03 17:01:40 -07:00 — 87 lines, 3.8 KiB, Python
import os
import logging
import importlib
import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.unet import IntegratedUNet2DConditionModel
logging.getLogger("diffusers").setLevel(logging.ERROR)
dir_path = os.path.dirname(__file__)
def load_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
    """Instantiate one pipeline component from a merged checkpoint state dict.

    Args:
        guess: huggingface_guess result; supplies ``unet_config`` for the UNet branch.
        component_name: pipeline slot name ('vae', 'text_encoder', 'unet', ...).
        lib_name: library the pipeline config assigns to the slot
            ('transformers' or 'diffusers'); anything else is skipped.
        cls_name: class name the pipeline config assigns to the slot.
        repo_path: local path of the bundled huggingface repo skeleton.
        state_dict: full checkpoint state dict; the keys relevant to each
            component are filtered out of it.

    Returns:
        The constructed component, or None for skipped/unsupported slots.

    Raises:
        ValueError: for a text-encoder slot name other than
            'text_encoder' / 'text_encoder_2'.
    """
    config_path = os.path.join(repo_path, component_name)

    # These slots are intentionally never loaded from the checkpoint.
    if component_name in ['feature_extractor', 'safety_checker']:
        return None

    if lib_name in ['transformers', 'diffusers']:
        # Schedulers and tokenizers carry no checkpoint weights; load them
        # directly from the bundled repo config via the library class.
        if component_name in ['scheduler'] or component_name.startswith('tokenizer'):
            cls = getattr(importlib.import_module(lib_name), cls_name)
            return cls.from_pretrained(os.path.join(repo_path, component_name))

        if cls_name in ['AutoencoderKL']:
            sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
            config = IntegratedAutoencoderKL.load_config(config_path)
            with using_forge_operations():
                model = IntegratedAutoencoderKL.from_config(config)
            load_state_dict(model, sd)
            return model

        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
            if component_name == 'text_encoder':
                sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
            elif component_name == 'text_encoder_2':
                sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
            else:
                raise ValueError(f"Wrong component_name: {component_name}")

            # Legacy (ldm-style) CLIP weights: convert layout and remap the
            # remaining 'model.*' keys to the transformers naming scheme.
            if 'model.text_projection' in sd:
                sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
                sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
                                                 "model.text_projection.weight": "text_projection",
                                                 "model.logit_scale": "logit_scale"})

            config = CLIPTextConfig.from_pretrained(config_path)
            # no_init_weights skips random init; real weights arrive from sd below.
            with modeling_utils.no_init_weights():
                with using_forge_operations():
                    model = IntegratedCLIP(config)
            load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
                                                      'transformer.text_model.embeddings.position_ids'])
            return model

        if cls_name == 'UNet2DConditionModel':
            sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
            with using_forge_operations():
                model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
            load_state_dict(model, sd)
            return model

    # Anything not handled above is reported and skipped.
    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None
def load_huggingface_components(sd):
    """Build every pipeline component for the checkpoint state dict ``sd``.

    The checkpoint architecture is guessed with ``huggingface_guess``, the
    matching bundled repo's pipeline config is read, and each configured
    component slot is constructed via :func:`load_component`.

    Args:
        sd: full checkpoint state dict.

    Returns:
        Dict mapping component names to constructed components, plus a
        'repo_path' entry pointing at the local repo skeleton.
    """
    guess = huggingface_guess.guess(sd)
    repo_name = guess.huggingface_repo
    local_path = os.path.join(dir_path, 'huggingface', repo_name)
    config = DiffusionPipeline.load_config(local_path)
    result = {"repo_path": local_path}
    for component_name, v in config.items():
        # Component entries in the pipeline config are [lib_name, cls_name] pairs;
        # other entries (e.g. '_class_name') are metadata and ignored.
        if isinstance(v, list) and len(v) == 2:
            lib_name, cls_name = v
            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
            if component is not None:
                result[component_name] = component
    return result