diff --git a/backend/loader.py b/backend/loader.py
index 32d1f8b0..51f6b6bf 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -1,7 +1,8 @@
 import os
+import logging
 import importlib
+import huggingface_guess
 
-from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
@@ -11,10 +12,11 @@ from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
 from backend.nn.unet import IntegratedUNet2DConditionModel
 
 
+logging.getLogger("diffusers").setLevel(logging.ERROR)
 dir_path = os.path.dirname(__file__)
 
 
-def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
+def load_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)
 
     if component_name in ['feature_extractor', 'safety_checker']:
@@ -58,10 +60,9 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
         return model
 
     if cls_name == 'UNet2DConditionModel':
         sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
-        config = IntegratedUNet2DConditionModel.load_config(config_path)
         with using_forge_operations():
-            model = IntegratedUNet2DConditionModel.from_config(config)
+            model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
 
         load_state_dict(model, sd)
         return model
@@ -70,20 +71,16 @@ def load_component(component_name, lib_name, cls_name, repo_path, state_dict):
     return None
 
 
-def guess_repo_name_from_state_dict(sd):
-    result = fetch_diffusers_config(sd)['pretrained_model_name_or_path']
-    return result
-
-
 def load_huggingface_components(sd):
-    repo_name = guess_repo_name_from_state_dict(sd)
+    guess = huggingface_guess.guess(sd)
+    repo_name = guess.huggingface_repo
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
     config = DiffusionPipeline.load_config(local_path)
     result = {"repo_path": local_path}
     for component_name, v in config.items():
         if isinstance(v, list) and len(v) == 2:
             lib_name, cls_name = v
-            component = load_component(component_name, lib_name, cls_name, local_path, sd)
+            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
             if component is not None:
                 result[component_name] = component
     return result
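
Note on the loader change: model detection now comes from the checkpoint itself. huggingface_guess.guess() returns an object whose huggingface_repo names the config folder under backend/huggingface/ and whose unet_config carries the ldm-style kwargs consumed by the rewritten UNet constructor (see unet.py below). A minimal sketch of the flow, assuming a local .safetensors checkpoint; the path and prints are illustrative, not part of this patch:

    import huggingface_guess
    from safetensors.torch import load_file

    sd = load_file('models/checkpoint.safetensors')  # hypothetical path
    guess = huggingface_guess.guess(sd)    # inspects state-dict keys/shapes
    print(guess.huggingface_repo)          # folder looked up under backend/huggingface/
    print(guess.unet_config)               # kwargs later passed to from_config()
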
diff --git a/backend/nn/unet.py b/backend/nn/unet.py
index 954db913..d99c8ee8 100644
--- a/backend/nn/unet.py
+++ b/backend/nn/unet.py
@@ -1,16 +1,10 @@
 import math
 import torch
-import torch as th
-import torch.nn.functional as F
-from typing import Optional, Tuple, Union
-from diffusers.configuration_utils import ConfigMixin, register_to_config
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
 
-
-unet_initial_dtype = torch.float16
-unet_initial_device = None
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 
 def checkpoint(f, args, parameters, enable=False):
@@ -121,7 +115,7 @@ class GEGLU(nn.Module):
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+        return x * torch.nn.functional.gelu(gate)
 
 
 class FeedForward(nn.Module):
@@ -162,6 +156,7 @@ class CrossAttention(nn.Module):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
+
         if value is not None:
             v = self.to_v(value)
             del value
@@ -178,6 +173,7 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
 
         self.ff_in = ff_in or inner_dim is not None
+
         if inner_dim is None:
             inner_dim = dim
 
@@ -204,7 +200,7 @@ class BasicTransformerBlock(nn.Module):
         self.d_head = d_head
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
+        return checkpoint(self._forward, (x, context, transformer_options), None, self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
         # Stolen from ComfyUI with some modifications
@@ -404,7 +400,7 @@ class Upsample(nn.Module):
             shape[0] = output_shape[2]
             shape[1] = output_shape[3]
 
-        x = F.interpolate(x, size=shape, mode="nearest")
+        x = torch.nn.functional.interpolate(x, size=shape, mode="nearest")
         if self.use_conv:
             x = self.conv(x)
         return x
@@ -513,9 +509,7 @@ class ResBlock(TimestepBlock):
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
 
     def forward(self, x, emb, transformer_options={}):
-        return checkpoint(
-            self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint
-        )
+        return checkpoint(self._forward, (x, emb, transformer_options), None, self.use_checkpoint)
 
     def _forward(self, x, emb, transformer_options={}):
         if self.updown:
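
Note on the checkpoint() call sites: forward() previously handed self.parameters() to the helper on every call even though activation checkpointing is never enabled here. A hedged sketch of why passing None is safe, assuming the local helper (whose body sits outside these hunks) only reads `parameters` when `enable` is truthy:

    import torch.utils.checkpoint

    def checkpoint(f, args, parameters, enable=False):
        # With enable=False -- the default at both call sites above --
        # `parameters` is never read, so callers can pass None instead of
        # handing over a parameter iterator on every forward pass.
        if not enable:
            return f(*args)
        # Illustrative fallback only; not the helper's actual body.
        return torch.utils.checkpoint.checkpoint(f, *args)
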
@@ -570,119 +564,37 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
     config_name = 'config.json'
 
     @register_to_config
-    def __init__(self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4,
-                 center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0,
-                 down_block_types: Tuple[str] = (
-                         "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D",),
-                 mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
-                 up_block_types: Tuple[str] = (
-                         "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
-                 only_cross_attention: Union[bool, Tuple[bool]] = False,
-                 block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2,
-                 downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0,
-                 act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5,
-                 cross_attention_dim: Union[int, Tuple[int]] = 1280,
-                 transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
-                 reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
-                 encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None,
-                 attention_head_dim: Union[int, Tuple[int]] = 8,
-                 num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False,
-                 use_linear_projection: bool = False, class_embed_type: Optional[str] = None,
-                 addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None,
-                 num_class_embeds: Optional[int] = None, upcast_attention: bool = False,
-                 resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False,
-                 resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional",
-                 time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None,
-                 timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None,
-                 conv_in_kernel: int = 3, conv_out_kernel: int = 3,
-                 projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default",
-                 class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None,
-                 cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, *args, **kwargs):
+    def __init__(
+            self,
+            in_channels,
+            model_channels,
+            out_channels,
+            num_res_blocks,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            num_classes=None,
+            use_checkpoint=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_spatial_transformer=False,
+            transformer_depth=1,
+            context_dim=None,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+            adm_in_channels=None,
+            transformer_depth_middle=None,
+            transformer_depth_output=None,
+            dtype=None,
+            device=None,
+    ):
         super().__init__()
 
-        in_channels = in_channels
-        out_channels = out_channels
-        model_channels = block_out_channels[0]
-        num_res_blocks = [layers_per_block] * len(block_out_channels)
-        dropout = dropout
-        channel_mult = [x // model_channels for x in block_out_channels]
-        conv_resample = True
-        dims = 2
-        num_classes = None
-        use_checkpoint = False
-        adm_in_channels = None
-        num_heads = -1
-        num_head_channels = -1
-        num_heads_upsample = -1
-        use_scale_shift_norm = False
-        resblock_updown = False
-        use_spatial_transformer = True
-        transformer_depth = []
-        transformer_depth_output = []
-        transformer_depth_middle = 1
-        context_dim = cross_attention_dim
-        disable_self_attentions: list = None
-        num_attention_blocks: list = None
-        disable_middle_self_attn = False
-        use_linear_in_transformer = use_linear_projection
-
-        for i, d in enumerate(down_block_types):
-            if 'attn' in d.lower():
-                current_transformer_depth = 1
-                if isinstance(transformer_layers_per_block, list) and len(transformer_layers_per_block) > i:
-                    current_transformer_depth = transformer_layers_per_block[i]
-                transformer_depth += [current_transformer_depth] * 2
-                transformer_depth_output += [current_transformer_depth] * 3
-            else:
-                transformer_depth += [0] * 2
-                transformer_depth_output += [0] * 3
-
-        if transformer_depth_output[-1] > 1:
-            transformer_depth_middle = transformer_depth_output[-1]
-
-        if isinstance(attention_head_dim, int):
-            num_heads = attention_head_dim
-        elif isinstance(attention_head_dim, list):
-            num_head_channels = model_channels // attention_head_dim[0]
-        else:
-            raise ValueError('Wrong attention heads!')
-
-        if isinstance(projection_class_embeddings_input_dim, int) and projection_class_embeddings_input_dim > 0:
-            num_classes = 'sequential'
-            adm_in_channels = projection_class_embeddings_input_dim
-
-        dtype = unet_initial_dtype
-        device = unet_initial_device
-
-        self.legacy_config = dict(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            model_channels=model_channels,
-            num_res_blocks=num_res_blocks,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            num_classes=num_classes,
-            dtype=dtype,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            use_spatial_transformer=use_spatial_transformer,
-            transformer_depth=transformer_depth,
-            context_dim=context_dim,
-            disable_self_attentions=disable_self_attentions,
-            num_attention_blocks=num_attention_blocks,
-            disable_middle_self_attn=disable_middle_self_attn,
-            use_linear_in_transformer=use_linear_in_transformer,
-            adm_in_channels=adm_in_channels,
-            transformer_depth_middle=transformer_depth_middle,
-            transformer_depth_output=transformer_depth_output,
-            device=device,
-        )
-
         if context_dim is not None:
             assert use_spatial_transformer
 
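
Note on the constructor change: the diffusers-style signature and the legacy_config translation layer are gone; the ldm-style kwargs now arrive verbatim from guess.unet_config, which ConfigMixin.from_config routes through @register_to_config into __init__. A sketch with illustrative SD 1.5-like values (the real dict comes from huggingface_guess and is not spelled out in this patch):

    # Values below are illustrative, not copied from the library.
    unet_config = dict(
        in_channels=4,
        model_channels=320,
        out_channels=4,
        num_res_blocks=2,
        channel_mult=(1, 2, 4, 4),
        num_heads=8,
        use_spatial_transformer=True,  # required whenever context_dim is set
        transformer_depth=1,
        context_dim=768,
        adm_in_channels=None,
    )
    model = IntegratedUNet2DConditionModel.from_config(unet_config)

With dtype and device now plain constructor arguments (defaulting to None) instead of the removed unet_initial_dtype/unet_initial_device module globals, precision and placement are chosen per instantiation, with using_forge_operations() wrapping the call in loader.py.
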
@@ -699,12 +611,11 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
-            if len(num_res_blocks) != len(channel_mult):
-                raise ValueError("Bad num_res_blocks")
             self.num_res_blocks = num_res_blocks
 
         if disable_self_attentions is not None:
             assert len(disable_self_attentions) == len(channel_mult)
+
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
 
@@ -988,7 +899,7 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
                 for p in patch:
                     h, hsp = p(h, hsp, transformer_options)
 
-            h = th.cat([h, hsp], dim=1)
+            h = torch.cat([h, hsp], dim=1)
             del hsp
             if len(hs) > 0:
                 output_shape = hs[-1].shape
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index da7a788a..5b31861c 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -401,7 +401,7 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "0c556001140dd4f141cea56f5679829602e06c8b")
+    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     try:
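
Note on the pin bump: the HUGGINGFACE_GUESS_HASH default moves to the commit exposing the huggingface_repo/unet_config attributes that loader.py consumes above. A hedged sketch of how the hash is typically consumed later in prepare_environment(); the repo-URL variable name is assumed, while git_clone and repo_dir are helpers already defined in modules/launch_utils.py:

    # Assumed wiring, outside the hunk above; names may differ slightly.
    huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', "https://github.com/lllyasviel/huggingface_guess")
    git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
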