mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git

rework model loader
@@ -1,7 +1,8 @@
 import os
 import logging
 import importlib
-from diffusers.loaders.single_file_utils import fetch_diffusers_config
+import huggingface_guess
+
 from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
@@ -11,10 +12,11 @@ from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
 from backend.nn.unet import IntegratedUNet2DConditionModel


 logging.getLogger("diffusers").setLevel(logging.ERROR)
 dir_path = os.path.dirname(__file__)


-def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
+def load_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)

     if component_name in ['feature_extractor', 'safety_checker']:
@@ -58,10 +60,9 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
         return model
     if cls_name == 'UNet2DConditionModel':
         sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
-        config = IntegratedUNet2DConditionModel.load_config(config_path)

         with using_forge_operations():
-            model = IntegratedUNet2DConditionModel.from_config(config)
+            model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)

         load_state_dict(model, sd)
         return model
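
Note: the key change in the hunk above is that the UNet config no longer comes from the local config.json (load_config(config_path)) but from the guessed checkpoint (guess.unet_config). A minimal sketch of the diffusers ConfigMixin pattern this relies on, assuming from_config() accepts a plain dict; the class and field names here are illustrative, not from the repo:

    # Sketch: ConfigMixin lets a plain dict instantiate the registered __init__.
    import torch.nn as nn
    from diffusers.configuration_utils import ConfigMixin, register_to_config

    class TinyModel(nn.Module, ConfigMixin):
        config_name = 'config.json'

        @register_to_config
        def __init__(self, in_channels=4, model_channels=320):
            super().__init__()
            self.proj = nn.Conv2d(in_channels, model_channels, 1)

    # A dict of the same shape as guess.unet_config can then be forwarded directly:
    m = TinyModel.from_config({'in_channels': 4, 'model_channels': 64})
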
@@ -70,20 +71,16 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     return None


-def guess_repo_name_from_state_dict(sd):
-    result = fetch_diffusers_config(sd)['pretrained_model_name_or_path']
-    return result
-
-
 def load_huggingface_components(sd):
-    repo_name = guess_repo_name_from_state_dict(sd)
+    guess = huggingface_guess.guess(sd)
+    repo_name = guess.huggingface_repo
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
     config = DiffusionPipeline.load_config(local_path)
     result = {"repo_path": local_path}
     for component_name, v in config.items():
         if isinstance(v, list) and len(v) == 2:
             lib_name, cls_name = v
-            component = load_component(component_name, lib_name, cls_name, local_path, sd)
+            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
             if component is not None:
                 result[component_name] = component
     return result
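
Taken together, the loader hunks replace the old fetch_diffusers_config() lookup with a single huggingface_guess probe whose result is threaded through every load_component() call. A rough sketch of the reworked flow; torch.load stands in for whatever state-dict reader the repo actually uses, so this is a sketch, not the loader itself:

    import torch
    import huggingface_guess

    def sketch_load_components(checkpoint_path):
        sd = torch.load(checkpoint_path, map_location='cpu')  # stand-in reader
        guess = huggingface_guess.guess(sd)  # one probe of the weights
        repo_name = guess.huggingface_repo   # selects the local config folder
        # guess.unet_config later feeds IntegratedUNet2DConditionModel.from_config()
        return guess, repo_name
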
@@ -1,16 +1,10 @@
 import math
 import torch
-import torch as th
-import torch.nn.functional as F

 from typing import Optional, Tuple, Union
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function

-unet_initial_dtype = torch.float16
-unet_initial_device = None
-from diffusers.configuration_utils import ConfigMixin, register_to_config


 def checkpoint(f, args, parameters, enable=False):
@@ -121,7 +115,7 @@ class GEGLU(nn.Module):

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+        return x * torch.nn.functional.gelu(gate)


 class FeedForward(nn.Module):
@@ -162,6 +156,7 @@ class CrossAttention(nn.Module):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
+
         if value is not None:
             v = self.to_v(value)
             del value
@@ -178,6 +173,7 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
+
         self.ff_in = ff_in or inner_dim is not None

         if inner_dim is None:
             inner_dim = dim

@@ -204,7 +200,7 @@ class BasicTransformerBlock(nn.Module):
         self.d_head = d_head

     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
+        return checkpoint(self._forward, (x, context, transformer_options), None, self.checkpoint)

     def _forward(self, x, context=None, transformer_options={}):
         # Stolen from ComfyUI with some modifications
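
This hunk and the ResBlock hunk further down make the same change: the call sites now pass None where self.parameters() used to go, which only makes sense if checkpoint() ignores that argument on the non-checkpointing path. A hedged sketch of a shim consistent with these call sites; the real backend implementation may differ:

    import torch.utils.checkpoint

    def checkpoint(f, args, parameters, enable=False):
        # With enable=False the parameters argument is never touched,
        # so callers may pass None instead of gathering self.parameters().
        if enable:
            # Assumption: real checkpointing would defer to torch.utils.
            return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
        return f(*args)
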
@@ -404,7 +400,7 @@ class Upsample(nn.Module):
                 shape[0] = output_shape[2]
                 shape[1] = output_shape[3]

-            x = F.interpolate(x, size=shape, mode="nearest")
+            x = torch.nn.functional.interpolate(x, size=shape, mode="nearest")
         if self.use_conv:
             x = self.conv(x)
         return x
@@ -513,9 +509,7 @@ class ResBlock(TimestepBlock):
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb, transformer_options={}):
-        return checkpoint(
-            self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint
-        )
+        return checkpoint(self._forward, (x, emb, transformer_options), None, self.use_checkpoint)

     def _forward(self, x, emb, transformer_options={}):
         if self.updown:
@@ -570,119 +564,37 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
     config_name = 'config.json'

     @register_to_config
-    def __init__(self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4,
-                 center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0,
-                 down_block_types: Tuple[str] = (
-                         "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D",),
-                 mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
-                 up_block_types: Tuple[str] = (
-                         "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
-                 only_cross_attention: Union[bool, Tuple[bool]] = False,
-                 block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2,
-                 downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0,
-                 act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5,
-                 cross_attention_dim: Union[int, Tuple[int]] = 1280,
-                 transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
-                 reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
-                 encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None,
-                 attention_head_dim: Union[int, Tuple[int]] = 8,
-                 num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False,
-                 use_linear_projection: bool = False, class_embed_type: Optional[str] = None,
-                 addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None,
-                 num_class_embeds: Optional[int] = None, upcast_attention: bool = False,
-                 resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False,
-                 resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional",
-                 time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None,
-                 timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None,
-                 conv_in_kernel: int = 3, conv_out_kernel: int = 3,
-                 projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default",
-                 class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None,
-                 cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, *args, **kwargs):
+    def __init__(
+            self,
+            in_channels,
+            model_channels,
+            out_channels,
+            num_res_blocks,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            num_classes=None,
+            use_checkpoint=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_spatial_transformer=False,
+            transformer_depth=1,
+            context_dim=None,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+            adm_in_channels=None,
+            transformer_depth_middle=None,
+            transformer_depth_output=None,
+            dtype=None,
+            device=None,
+    ):
         super().__init__()

-        in_channels = in_channels
-        out_channels = out_channels
-        model_channels = block_out_channels[0]
-        num_res_blocks = [layers_per_block] * len(block_out_channels)
-        dropout = dropout
-        channel_mult = [x // model_channels for x in block_out_channels]
-        conv_resample = True
-        dims = 2
-        num_classes = None
-        use_checkpoint = False
-        adm_in_channels = None
-        num_heads = -1
-        num_head_channels = -1
-        num_heads_upsample = -1
-        use_scale_shift_norm = False
-        resblock_updown = False
-        use_spatial_transformer = True
-        transformer_depth = []
-        transformer_depth_output = []
-        transformer_depth_middle = 1
-        context_dim = cross_attention_dim
-        disable_self_attentions: list = None
-        num_attention_blocks: list = None
-        disable_middle_self_attn = False
-        use_linear_in_transformer = use_linear_projection
-
-        for i, d in enumerate(down_block_types):
-            if 'attn' in d.lower():
-                current_transformer_depth = 1
-                if isinstance(transformer_layers_per_block, list) and len(transformer_layers_per_block) > i:
-                    current_transformer_depth = transformer_layers_per_block[i]
-                transformer_depth += [current_transformer_depth] * 2
-                transformer_depth_output += [current_transformer_depth] * 3
-            else:
-                transformer_depth += [0] * 2
-                transformer_depth_output += [0] * 3
-
-        if transformer_depth_output[-1] > 1:
-            transformer_depth_middle = transformer_depth_output[-1]
-
-        if isinstance(attention_head_dim, int):
-            num_heads = attention_head_dim
-        elif isinstance(attention_head_dim, list):
-            num_head_channels = model_channels // attention_head_dim[0]
-        else:
-            raise ValueError('Wrong attention heads!')
-
-        if isinstance(projection_class_embeddings_input_dim, int) and projection_class_embeddings_input_dim > 0:
-            num_classes = 'sequential'
-            adm_in_channels = projection_class_embeddings_input_dim
-
-        dtype = unet_initial_dtype
-        device = unet_initial_device
-
-        self.legacy_config = dict(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            model_channels=model_channels,
-            num_res_blocks=num_res_blocks,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            num_classes=num_classes,
-            dtype=dtype,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            use_spatial_transformer=use_spatial_transformer,
-            transformer_depth=transformer_depth,
-            context_dim=context_dim,
-            disable_self_attentions=disable_self_attentions,
-            num_attention_blocks=num_attention_blocks,
-            disable_middle_self_attn=disable_middle_self_attn,
-            use_linear_in_transformer=use_linear_in_transformer,
-            adm_in_channels=adm_in_channels,
-            transformer_depth_middle=transformer_depth_middle,
-            transformer_depth_output=transformer_depth_output,
-            device=device,
-        )

         if context_dim is not None:
             assert use_spatial_transformer
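
The rework inverts the old adapter: instead of accepting a diffusers-style config (block_out_channels, transformer_layers_per_block, ...) and translating it to ldm-style names inside __init__, the constructor now takes the ldm-style arguments directly, and the translation happens upstream in huggingface_guess. For illustration only, a dict of the shape from_config(guess.unet_config) would forward; the values resemble SD 1.5 but are assumptions, not taken from the repo:

    example_unet_config = dict(
        in_channels=4,
        out_channels=4,
        model_channels=320,
        num_res_blocks=2,
        channel_mult=(1, 2, 4, 4),
        use_spatial_transformer=True,
        transformer_depth=(1, 1, 1, 1, 1, 1, 0, 0),
        context_dim=768,
        num_heads=8,
        use_linear_in_transformer=False,
        adm_in_channels=None,
    )
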
@@ -699,12 +611,11 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
             if len(num_res_blocks) != len(channel_mult):
                 raise ValueError("Bad num_res_blocks")
             self.num_res_blocks = num_res_blocks

         if disable_self_attentions is not None:
             assert len(disable_self_attentions) == len(channel_mult)

         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)

@@ -988,7 +899,7 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
             for p in patch:
                 h, hsp = p(h, hsp, transformer_options)

-            h = th.cat([h, hsp], dim=1)
+            h = torch.cat([h, hsp], dim=1)
             del hsp
             if len(hs) > 0:
                 output_shape = hs[-1].shape

@@ -401,7 +401,7 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "0c556001140dd4f141cea56f5679829602e06c8b")
+    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try: