rework model loader

layerdiffusion
2024-08-03 16:23:32 -07:00
parent 430482d1a0
commit fb3052350b
3 changed files with 47 additions and 139 deletions

View File

@@ -1,7 +1,8 @@
 import os
 import logging
 import importlib
+import huggingface_guess
 from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
@@ -11,10 +12,11 @@ from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
 from backend.nn.unet import IntegratedUNet2DConditionModel
 
 logging.getLogger("diffusers").setLevel(logging.ERROR)
 
 dir_path = os.path.dirname(__file__)
 
 
-def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
+def load_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)
 
     if component_name in ['feature_extractor', 'safety_checker']:
@@ -58,10 +60,9 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
         return model
 
     if cls_name == 'UNet2DConditionModel':
         sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
-        config = IntegratedUNet2DConditionModel.load_config(config_path)
         with using_forge_operations():
-            model = IntegratedUNet2DConditionModel.from_config(config)
+            model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
 
         load_state_dict(model, sd)
         return model
@@ -70,20 +71,16 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     return None
 
 
-def guess_repo_name_from_state_dict(sd):
-    result = fetch_diffusers_config(sd)['pretrained_model_name_or_path']
-    return result
-
-
 def load_huggingface_components(sd):
-    repo_name = guess_repo_name_from_state_dict(sd)
+    guess = huggingface_guess.guess(sd)
+    repo_name = guess.huggingface_repo
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
     config = DiffusionPipeline.load_config(local_path)
     result = {"repo_path": local_path}
     for component_name, v in config.items():
         if isinstance(v, list) and len(v) == 2:
             lib_name, cls_name = v
-            component = load_component(component_name, lib_name, cls_name, local_path, sd)
+            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
             if component is not None:
                 result[component_name] = component
     return result
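
For orientation, a hedged usage sketch of the reworked entry point (not part of this commit; the safetensors loading and the 'unet' component name are assumptions based on the usual diffusers model_index.json layout):

# Load a raw checkpoint state dict and let huggingface_guess pick the
# matching local config folder and the ldm-style UNet config.
from safetensors.torch import load_file

state_dict = load_file('/path/to/checkpoint.safetensors')  # illustrative path
components = load_huggingface_components(state_dict)

unet = components['unet']            # IntegratedUNet2DConditionModel
repo_path = components['repo_path']  # local huggingface config folder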

View File

@@ -1,16 +1,10 @@
 import math
 import torch
-import torch as th
-import torch.nn.functional as F
 from typing import Optional, Tuple, Union
-from diffusers.configuration_utils import ConfigMixin, register_to_config
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
-
-unet_initial_dtype = torch.float16
-unet_initial_device = None
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 
 def checkpoint(f, args, parameters, enable=False):
@@ -121,7 +115,7 @@ class GEGLU(nn.Module):
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+        return x * torch.nn.functional.gelu(gate)
 
 
 class FeedForward(nn.Module):
@@ -162,6 +156,7 @@ class CrossAttention(nn.Module):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
+
         if value is not None:
             v = self.to_v(value)
             del value
@@ -178,6 +173,7 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
+
         self.ff_in = ff_in or inner_dim is not None
 
         if inner_dim is None:
             inner_dim = dim
 
@@ -204,7 +200,7 @@ class BasicTransformerBlock(nn.Module):
         self.d_head = d_head
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
+        return checkpoint(self._forward, (x, context, transformer_options), None, self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
         # Stolen from ComfyUI with some modifications
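
Here and in ResBlock below, self.parameters() is replaced with None as the checkpoint wrapper's third argument. A sketch of the wrapper's assumed semantics (its real body is defined earlier in this file but not shown in the diff): with checkpointing disabled it reduces to a plain call, so the parameters argument is never touched and None is safe.

def checkpoint(f, args, parameters, enable=False):
    # Assumed behavior: the disabled path ignores `parameters` entirely.
    if not enable:
        return f(*args)
    raise NotImplementedError('gradient-checkpointing path omitted in this sketch')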
@@ -404,7 +400,7 @@ class Upsample(nn.Module):
                 shape[0] = output_shape[2]
                 shape[1] = output_shape[3]
 
-        x = F.interpolate(x, size=shape, mode="nearest")
+        x = torch.nn.functional.interpolate(x, size=shape, mode="nearest")
         if self.use_conv:
             x = self.conv(x)
         return x
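
The F.* replacements in this file are pure de-aliasing; F was bound to the same module object, so behavior is unchanged:

# Both names resolve to the same module, hence the same functions.
import torch
import torch.nn.functional as F

assert F is torch.nn.functional
assert F.interpolate is torch.nn.functional.interpolate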
@@ -513,9 +509,7 @@ class ResBlock(TimestepBlock):
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
 
     def forward(self, x, emb, transformer_options={}):
-        return checkpoint(
-            self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint
-        )
+        return checkpoint(self._forward, (x, emb, transformer_options), None, self.use_checkpoint)
 
     def _forward(self, x, emb, transformer_options={}):
         if self.updown:
@@ -570,119 +564,37 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
     config_name = 'config.json'
 
     @register_to_config
-    def __init__(self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4,
-                 center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0,
-                 down_block_types: Tuple[str] = (
-                         "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D",),
-                 mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
-                 up_block_types: Tuple[str] = (
-                         "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
-                 only_cross_attention: Union[bool, Tuple[bool]] = False,
-                 block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2,
-                 downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0,
-                 act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5,
-                 cross_attention_dim: Union[int, Tuple[int]] = 1280,
-                 transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
-                 reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
-                 encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None,
-                 attention_head_dim: Union[int, Tuple[int]] = 8,
-                 num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False,
-                 use_linear_projection: bool = False, class_embed_type: Optional[str] = None,
-                 addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None,
-                 num_class_embeds: Optional[int] = None, upcast_attention: bool = False,
-                 resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False,
-                 resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional",
-                 time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None,
-                 timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None,
-                 conv_in_kernel: int = 3, conv_out_kernel: int = 3,
-                 projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default",
-                 class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None,
-                 cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, *args, **kwargs):
+    def __init__(
+        self,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_spatial_transformer=False,
+        transformer_depth=1,
+        context_dim=None,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+        adm_in_channels=None,
+        transformer_depth_middle=None,
+        transformer_depth_output=None,
+        dtype=None,
+        device=None,
+    ):
         super().__init__()
-
-        in_channels = in_channels
-        out_channels = out_channels
-        model_channels = block_out_channels[0]
-        num_res_blocks = [layers_per_block] * len(block_out_channels)
-        dropout = dropout
-        channel_mult = [x // model_channels for x in block_out_channels]
-        conv_resample = True
-        dims = 2
-        num_classes = None
-        use_checkpoint = False
-        adm_in_channels = None
-        num_heads = -1
-        num_head_channels = -1
-        num_heads_upsample = -1
-        use_scale_shift_norm = False
-        resblock_updown = False
-        use_spatial_transformer = True
-        transformer_depth = []
-        transformer_depth_output = []
-        transformer_depth_middle = 1
-        context_dim = cross_attention_dim
-        disable_self_attentions: list = None
-        num_attention_blocks: list = None
-        disable_middle_self_attn = False
-        use_linear_in_transformer = use_linear_projection
-
-        for i, d in enumerate(down_block_types):
-            if 'attn' in d.lower():
-                current_transformer_depth = 1
-                if isinstance(transformer_layers_per_block, list) and len(transformer_layers_per_block) > i:
-                    current_transformer_depth = transformer_layers_per_block[i]
-                transformer_depth += [current_transformer_depth] * 2
-                transformer_depth_output += [current_transformer_depth] * 3
-            else:
-                transformer_depth += [0] * 2
-                transformer_depth_output += [0] * 3
-
-        if transformer_depth_output[-1] > 1:
-            transformer_depth_middle = transformer_depth_output[-1]
-
-        if isinstance(attention_head_dim, int):
-            num_heads = attention_head_dim
-        elif isinstance(attention_head_dim, list):
-            num_head_channels = model_channels // attention_head_dim[0]
-        else:
-            raise ValueError('Wrong attention heads!')
-
-        if isinstance(projection_class_embeddings_input_dim, int) and projection_class_embeddings_input_dim > 0:
-            num_classes = 'sequential'
-            adm_in_channels = projection_class_embeddings_input_dim
-
-        dtype = unet_initial_dtype
-        device = unet_initial_device
-
-        self.legacy_config = dict(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            model_channels=model_channels,
-            num_res_blocks=num_res_blocks,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            num_classes=num_classes,
-            dtype=dtype,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            use_spatial_transformer=use_spatial_transformer,
-            transformer_depth=transformer_depth,
-            context_dim=context_dim,
-            disable_self_attentions=disable_self_attentions,
-            num_attention_blocks=num_attention_blocks,
-            disable_middle_self_attn=disable_middle_self_attn,
-            use_linear_in_transformer=use_linear_in_transformer,
-            adm_in_channels=adm_in_channels,
-            transformer_depth_middle=transformer_depth_middle,
-            transformer_depth_output=transformer_depth_output,
-            device=device,
-        )
 
         if context_dim is not None:
             assert use_spatial_transformer
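
To make the new calling convention concrete: from_config (diffusers' ConfigMixin) forwards config keys straight into this __init__, so guess.unet_config is expected to be an ldm-style dict rather than a diffusers one. A hedged sketch with SDXL-like values (the numbers are illustrative assumptions, not read from this commit):

# Hypothetical SDXL-like config as huggingface_guess might report it;
# every key mirrors a parameter of the new __init__ above.
sdxl_like_config = dict(
    in_channels=4,
    model_channels=320,
    out_channels=4,
    num_res_blocks=[2, 2, 2],
    channel_mult=[1, 2, 4],
    use_spatial_transformer=True,
    use_linear_in_transformer=True,
    transformer_depth=[0, 0, 2, 2, 10, 10],
    transformer_depth_middle=10,
    context_dim=2048,
    num_head_channels=64,
    adm_in_channels=2816,
)
model = IntegratedUNet2DConditionModel.from_config(sdxl_like_config)
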
@@ -699,12 +611,11 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
             if len(num_res_blocks) != len(channel_mult):
-                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
-                                 "as a list/tuple (per-level) with the same length as channel_mult")
+                raise ValueError("Bad num_res_blocks")
             self.num_res_blocks = num_res_blocks
 
         if disable_self_attentions is not None:
             assert len(disable_self_attentions) == len(channel_mult)
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
@@ -988,7 +899,7 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
                 for p in patch:
                     h, hsp = p(h, hsp, transformer_options)
 
-            h = th.cat([h, hsp], dim=1)
+            h = torch.cat([h, hsp], dim=1)
             del hsp
 
             if len(hs) > 0:
                 output_shape = hs[-1].shape

View File

@@ -401,7 +401,7 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "0c556001140dd4f141cea56f5679829602e06c8b")
+    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     try:
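
The bumped pin is presumably consumed further down in prepare_environment by the repo-cloning step; a hedged sketch using the git_clone/repo_dir helpers these launch utilities provide (the URL and exact call site are assumptions, not shown in this diff):

# Clone (or fast-forward) the pinned huggingface_guess revision.
git_clone(
    'https://github.com/lllyasviel/huggingface_guess',
    repo_dir('huggingface_guess'),
    'huggingface_guess',
    huggingface_guess_commit_hash,
)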