mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-27 01:38:53 +00:00
rework model loader
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import logging
|
||||
import importlib
|
||||
import huggingface_guess
|
||||
|
||||
from diffusers.loaders.single_file_utils import fetch_diffusers_config
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
|
||||
@@ -11,10 +12,11 @@ from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
|
||||
from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
dir_path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
|
||||
def load_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
|
||||
config_path = os.path.join(repo_path, component_name)
|
||||
|
||||
if component_name in ['feature_extractor', 'safety_checker']:
|
||||
@@ -58,10 +60,9 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
|
||||
config = IntegratedUNet2DConditionModel.load_config(config_path)
|
||||
|
||||
with using_forge_operations():
|
||||
model = IntegratedUNet2DConditionModel.from_config(config)
|
||||
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
|
||||
|
||||
load_state_dict(model, sd)
|
||||
return model
|
||||
@@ -70,20 +71,16 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
|
||||
return None
|
||||
|
||||
|
||||
def guess_repo_name_from_state_dict(sd):
|
||||
result = fetch_diffusers_config(sd)['pretrained_model_name_or_path']
|
||||
return result
|
||||
|
||||
|
||||
def load_huggingface_components(sd):
|
||||
repo_name = guess_repo_name_from_state_dict(sd)
|
||||
guess = huggingface_guess.guess(sd)
|
||||
repo_name = guess.huggingface_repo
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
config = DiffusionPipeline.load_config(local_path)
|
||||
result = {"repo_path": local_path}
|
||||
for component_name, v in config.items():
|
||||
if isinstance(v, list) and len(v) == 2:
|
||||
lib_name, cls_name = v
|
||||
component = load_component(component_name, lib_name, cls_name, local_path, sd)
|
||||
component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
|
||||
if component is not None:
|
||||
result[component_name] = component
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user