mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
rework sd1.5 and sdxl from scratch
This commit is contained in:
@@ -1,16 +1,23 @@
|
||||
import os
|
||||
import torch
|
||||
import logging
|
||||
import importlib
|
||||
import huggingface_guess
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
|
||||
from backend.state_dict import try_filter_state_dict, load_state_dict
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.vae import IntegratedAutoencoderKL
|
||||
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
|
||||
from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusionXL]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
dir_path = os.path.dirname(__file__)
|
||||
@@ -27,61 +34,84 @@ def load_component(guess, component_name, lib_name, cls_name, repo_path, state_d
|
||||
cls = getattr(importlib.import_module(lib_name), cls_name)
|
||||
return cls.from_pretrained(os.path.join(repo_path, component_name))
|
||||
if cls_name in ['AutoencoderKL']:
|
||||
sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
|
||||
config = IntegratedAutoencoderKL.load_config(config_path)
|
||||
|
||||
with using_forge_operations():
|
||||
model = IntegratedAutoencoderKL.from_config(config)
|
||||
|
||||
load_state_dict(model, sd)
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
|
||||
if component_name == 'text_encoder':
|
||||
sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
|
||||
elif component_name == 'text_encoder_2':
|
||||
sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
|
||||
else:
|
||||
raise ValueError(f"Wrong component_name: {component_name}")
|
||||
|
||||
if 'model.text_projection' in sd:
|
||||
sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
|
||||
sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
|
||||
"model.text_projection.weight": "text_projection",
|
||||
"model.logit_scale": "logit_scale"})
|
||||
|
||||
config = CLIPTextConfig.from_pretrained(config_path)
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations():
|
||||
model = IntegratedCLIP(config)
|
||||
|
||||
load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
|
||||
'transformer.text_model.embeddings.position_ids'])
|
||||
load_state_dict(model, state_dict, ignore_errors=[
|
||||
'transformer.text_projection.weight',
|
||||
'transformer.text_model.embeddings.position_ids',
|
||||
'logit_scale'
|
||||
], log_name=cls_name)
|
||||
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
|
||||
|
||||
with using_forge_operations():
|
||||
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
|
||||
model._internal_dict = guess.unet_config
|
||||
|
||||
load_state_dict(model, sd)
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
return None
|
||||
|
||||
|
||||
def load_huggingface_components(sd):
|
||||
def split_state_dict(sd):
|
||||
guess = huggingface_guess.guess(sd)
|
||||
repo_name = guess.huggingface_repo
|
||||
|
||||
state_dict = {
|
||||
'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
|
||||
'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||
}
|
||||
|
||||
sd = guess.process_clip_state_dict(sd)
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
|
||||
for k, v in guess.clip_target.items():
|
||||
state_dict[v] = try_filter_state_dict(sd, [k + '.'])
|
||||
|
||||
state_dict['ignore'] = sd
|
||||
|
||||
print_dict = {k: len(v) for k, v in state_dict.items()}
|
||||
print(f'StateDict Keys: {print_dict}')
|
||||
|
||||
del state_dict['ignore']
|
||||
|
||||
return state_dict, guess
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forge_loader(sd):
|
||||
state_dicts, estimated_config = split_state_dict(sd)
|
||||
repo_name = estimated_config.huggingface_repo
|
||||
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
config = DiffusionPipeline.load_config(local_path)
|
||||
result = {"repo_path": local_path}
|
||||
config: dict = DiffusionPipeline.load_config(local_path)
|
||||
huggingface_components = {}
|
||||
for component_name, v in config.items():
|
||||
if isinstance(v, list) and len(v) == 2:
|
||||
lib_name, cls_name = v
|
||||
component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
|
||||
component_sd = state_dicts.get(component_name, None)
|
||||
component = load_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
|
||||
if component_sd is not None:
|
||||
del state_dicts[component_name]
|
||||
if component is not None:
|
||||
result[component_name] = component
|
||||
return result
|
||||
huggingface_components[component_name] = component
|
||||
|
||||
for M in possible_models:
|
||||
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
||||
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
|
||||
print('Failed to recognize model type!')
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user