diff --git a/backend/misc/image_resize.py b/backend/misc/image_resize.py
new file mode 100644
index 00000000..39906708
--- /dev/null
+++ b/backend/misc/image_resize.py
@@ -0,0 +1,113 @@
+import torch
+import numpy as np
+
+from PIL import Image
+
+
+def bislerp(samples, width, height):
+    def slerp(b1, b2, r):
+        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
+
+        c = b1.shape[-1]
+
+        # norms
+        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
+        b2_norms = torch.norm(b2, dim=-1, keepdim=True)
+
+        # normalize
+        b1_normalized = b1 / b1_norms
+        b2_normalized = b2 / b2_norms
+
+        # zero when norms are zero
+        b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
+        b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
+
+        # slerp
+        dot = (b1_normalized * b2_normalized).sum(1)
+        omega = torch.acos(dot)
+        so = torch.sin(omega)
+
+        # technically not mathematically correct, but more pleasing?
+        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
+        res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
+
+        # edge cases for same or polar opposites
+        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
+        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
+        return res
+
+    def generate_bilinear_data(length_old, length_new, device):
+        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
+        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
+        ratios = coords_1 - coords_1.floor()
+        coords_1 = coords_1.to(torch.int64)
+
+        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
+        coords_2[:, :, :, -1] -= 1
+        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
+        coords_2 = coords_2.to(torch.int64)
+        return ratios, coords_1, coords_2
+
+    orig_dtype = samples.dtype
+    samples = samples.float()
+    n, c, h, w = samples.shape
+    h_new, w_new = (height, width)
+
+    # linear w
+    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
+    coords_1 = coords_1.expand((n, c, h, -1))
+    coords_2 = coords_2.expand((n, c, h, -1))
+    ratios = ratios.expand((n, 1, h, -1))
+
+    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
+
+    result = slerp(pass_1, pass_2, ratios)
+    result = result.reshape(n, h, w_new, c).movedim(-1, 1)
+
+    # linear h
+    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
+    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
+
+    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
+
+    result = slerp(pass_1, pass_2, ratios)
+    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
+    return result.to(orig_dtype)
+
+
+def lanczos(samples, width, height):
+    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
+    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
+    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
+    result = torch.stack(images)
+    return result.to(samples.device, samples.dtype)
+
+
+def adaptive_resize(samples, width, height, upscale_method, crop):
+    if crop == "center":
+        old_width = samples.shape[3]
+        old_height = samples.shape[2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples[:, :, y:old_height - y, x:old_width - x]
+    else:
+        s = samples
+
+    if upscale_method == "bislerp":
+        return bislerp(s, width, height)
+    elif upscale_method == "lanczos":
+        return lanczos(s, width, height)
+    else:
+        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
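A quick sketch of how the new `adaptive_resize` helper behaves (the shapes and values here are illustrative, not taken from the diff):

```python
import torch

from backend.misc import image_resize

# NCHW tensor, e.g. a square control hint being fitted to a wider target
hint = torch.rand(1, 3, 512, 512)

# center-crops to the target aspect ratio first, then resizes with bislerp;
# output size is (height, width) = (512, 768)
resized = image_resize.adaptive_resize(hint, 768, 512, "bislerp", "center")
print(resized.shape)  # torch.Size([1, 3, 512, 768])
```

Any other `upscale_method` (e.g. `'nearest-exact'`, which the ControlNet hint paths below use) falls through to `torch.nn.functional.interpolate`.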
diff --git a/ldm_patched/controlnet/cldm.py b/backend/nn/cnets/cldm.py
similarity index 60%
rename from ldm_patched/controlnet/cldm.py
rename to backend/nn/cnets/cldm.py
index 82265ef9..6077ea31 100644
--- a/ldm_patched/controlnet/cldm.py
+++ b/backend/nn/cnets/cldm.py
@@ -1,60 +1,42 @@
-#taken from: https://github.com/lllyasviel/ControlNet
-#and modified
-
 import torch
-import torch as th
 import torch.nn as nn
-from ldm_patched.ldm.modules.diffusionmodules.util import (
-    zero_module,
-    timestep_embedding,
-)
+from backend.nn.unet import timestep_embedding, exists, conv_nd, SpatialTransformer, TimestepEmbedSequential, ResBlock, Downsample
 
-from ldm_patched.ldm.modules.attention import SpatialTransformer
-from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
-from ldm_patched.ldm.util import exists
-import ldm_patched.modules.ops
-
-class ControlledUnetModel(UNetModel):
-    #implemented in the ldm unet
-    pass
 
 class ControlNet(nn.Module):
     def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        hint_channels,
-        num_res_blocks,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        num_classes=None,
-        use_checkpoint=False,
-        dtype=torch.float32,
-        num_heads=-1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        use_spatial_transformer=False,    # custom transformer support
-        transformer_depth=1,              # custom transformer support
-        context_dim=None,                 # custom transformer support
-        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
-        legacy=True,
-        disable_self_attentions=None,
-        num_attention_blocks=None,
-        disable_middle_self_attn=False,
-        use_linear_in_transformer=False,
-        adm_in_channels=None,
-        transformer_depth_middle=None,
-        transformer_depth_output=None,
-        device=None,
-        operations=ldm_patched.modules.ops.disable_weight_init,
-        **kwargs,
+            self,
+            in_channels,
+            model_channels,
+            hint_channels,
+            num_res_blocks,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            num_classes=None,
+            use_checkpoint=False,
+            dtype=torch.float32,
+            num_heads=-1,
+            num_head_channels=-1,
+            num_heads_upsample=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_new_attention_order=False,
+            use_spatial_transformer=False,
+            transformer_depth=1,
+            context_dim=None,
+            n_embed=None,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+            adm_in_channels=None,
+            transformer_depth_middle=None,
+            transformer_depth_output=None,
+            device=None,
+            **kwargs,
     ):
         super().__init__()
         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -77,7 +59,6 @@ class ControlNet(nn.Module):
         assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
 
         self.dims = dims
-        self.image_size = image_size
         self.in_channels = in_channels
         self.model_channels = model_channels
@@ -111,9 +92,9 @@ class ControlNet(nn.Module):
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+            nn.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+            nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )
 
         if self.num_classes is not None:
@@ -126,9 +107,9 @@ class ControlNet(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                        nn.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                        nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:
@@ -137,28 +118,28 @@ class ControlNet(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                    nn.Conv2d(in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, dtype=self.dtype, device=device)])
 
         self.input_hint_block = TimestepEmbedSequential(
-            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
-            nn.SiLU(),
-            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+            conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
         )
 
         self._feature_size = model_channels
@@ -178,7 +159,6 @@ class ControlNet(nn.Module):
                         use_scale_shift_norm=use_scale_shift_norm,
                         dtype=self.dtype,
                         device=device,
-                        operations=operations,
                     )
                 ]
                 ch = mult * model_channels
@@ -189,9 +169,7 @@
                     else:
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
-                    if legacy:
-                        #num_heads = 1
-                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
                     if exists(disable_self_attentions):
                         disabled_sa = disable_self_attentions[level]
                     else:
@@ -202,11 +180,11 @@
                         SpatialTransformer(
                             ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
@@ -224,17 +202,16 @@
                             down=True,
                             dtype=self.dtype,
                             device=device,
-                            operations=operations
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )
                 ch = out_ch
                 input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
                 ds *= 2
                 self._feature_size += ch
@@ -243,9 +220,7 @@
         else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
-        if legacy:
-            #num_heads = 1
-            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
         mid_block = [
             ResBlock(
                 ch,
@@ -256,31 +231,30 @@
                 time_embed_dim,
                 dropout,
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 dtype=self.dtype,
                 device=device,
-                operations=operations
             )]
         if transformer_depth_middle >= 0:
-            mid_block += [SpatialTransformer(  # always uses a self-attn
-                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-                        ),
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            dtype=self.dtype,
-                            device=device,
-                            operations=operations
-                        )]
+            mid_block += [
+                SpatialTransformer(  # always uses a self-attn
+                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                ),
+                ResBlock(
+                    ch,
+                    time_embed_dim,
+                    dropout,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    dtype=self.dtype,
+                    device=device,
+                )]
         self.middle_block = TimestepEmbedSequential(*mid_block)
-        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
+        self.middle_block_out = self.make_zero_conv(ch, dtype=self.dtype, device=device)
         self._feature_size += ch
 
-    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
-        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
+    def make_zero_conv(self, channels, dtype=None, device=None):
+        return TimestepEmbedSequential(conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
 
     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
@@ -309,4 +283,3 @@
         outs.append(self.middle_block_out(h, emb, context))
 
         return outs
-
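For orientation, a rough construction sketch for the relocated module. The hyperparameters below are illustrative SD1.5-scale assumptions, not values from this diff; in the real code paths the kwargs come from model detection or from `legacy_config` (see the `backend/nn/unet.py` and `backend/patcher/controlnet.py` hunks further down):

```python
import torch

from backend.nn.cnets import cldm

# illustrative SD1.5-scale hyperparameters (assumptions, not from this diff)
control_model = cldm.ControlNet(
    in_channels=4,
    model_channels=320,
    hint_channels=3,
    num_res_blocks=[2, 2, 2, 2],
    channel_mult=(1, 2, 4, 4),
    use_spatial_transformer=True,  # asserted True in __init__
    transformer_depth=[1] * 12,    # assumed to be consumed per block; extra entries unused
    transformer_depth_middle=1,
    context_dim=768,
    num_heads=8,
)

x = torch.randn(1, 4, 64, 64)       # noisy latents
hint = torch.randn(1, 3, 512, 512)  # pixel-space control image (8x the latent size)
t = torch.tensor([999])
context = torch.randn(1, 77, 768)   # text conditioning

# returns one feature map per zero conv, plus the middle block output
outs = control_model(x, hint, t, context)
```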
diff --git a/ldm_patched/t2ia/adapter.py b/backend/nn/cnets/t2i_adapter.py
similarity index 97%
rename from ldm_patched/t2ia/adapter.py
rename to backend/nn/cnets/t2i_adapter.py
index e9a606b1..7b9e8599 100644
--- a/ldm_patched/t2ia/adapter.py
+++ b/backend/nn/cnets/t2i_adapter.py
@@ -1,6 +1,6 @@
-#taken from https://github.com/TencentARC/T2I-Adapter
 import torch
 import torch.nn as nn
+
 from collections import OrderedDict
@@ -274,9 +274,9 @@ class Adapter_light(nn.Module):
         for i in range(len(channels)):
             if i == 0:
-                self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
+                self.body.append(extractor(in_c=cin, inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=False))
             else:
-                self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
+                self.body.append(extractor(in_c=channels[i - 1], inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=True))
         self.body = nn.ModuleList(self.body)
 
     def forward(self, x):
diff --git a/backend/nn/unet.py b/backend/nn/unet.py
index 41c1037e..954db913 100644
--- a/backend/nn/unet.py
+++ b/backend/nn/unet.py
@@ -655,12 +655,32 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
             device = unet_initial_device
 
         self.legacy_config = dict(
-            num_res_blocks=num_res_blocks,
-            channel_mult=channel_mult,
-            transformer_depth=transformer_depth,
-            transformer_depth_output=transformer_depth_output,
-            transformer_depth_middle=transformer_depth_middle,
+            in_channels=in_channels,
+            out_channels=out_channels,
             model_channels=model_channels,
+            num_res_blocks=num_res_blocks,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            num_classes=num_classes,
+            dtype=dtype,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            use_spatial_transformer=use_spatial_transformer,
+            transformer_depth=transformer_depth,
+            context_dim=context_dim,
+            disable_self_attentions=disable_self_attentions,
+            num_attention_blocks=num_attention_blocks,
+            disable_middle_self_attn=disable_middle_self_attn,
+            use_linear_in_transformer=use_linear_in_transformer,
+            adm_in_channels=adm_in_channels,
+            transformer_depth_middle=transformer_depth_middle,
+            transformer_depth_output=transformer_depth_output,
+            device=device,
         )
 
         if context_dim is not None:
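The point of recording the full constructor signature in `legacy_config` is that a matching ControlNet can be rebuilt from a live UNet, which is exactly what `ControlLora.pre_run` does in `backend/patcher/controlnet.py` below. A minimal sketch under that assumption (`unet` and `controlnet_from_unet` are hypothetical names):

```python
from backend.nn.cnets import cldm


def controlnet_from_unet(unet, hint_channels=3):
    """Rebuild a cldm.ControlNet mirroring a loaded UNet's architecture."""
    config = unet.legacy_config.copy()
    config.pop("out_channels")               # the control half has no output head
    config["hint_channels"] = hint_channels  # in practice read from input_hint_block weights
    return cldm.ControlNet(**config)
```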
diff --git a/backend/operations.py b/backend/operations.py
index 4c6adc29..8a40fa09 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -150,11 +150,13 @@ class ForgeOperationsWithManualCast(ForgeOperations):
 
 
 @contextlib.contextmanager
-def using_forge_operations(parameters_manual_cast=False):
-    operations = ForgeOperations
+def using_forge_operations(parameters_manual_cast=False, operations=None):
 
-    if parameters_manual_cast:
-        operations = ForgeOperationsWithManualCast
+    if operations is None:
+        operations = ForgeOperations
+
+    if parameters_manual_cast:
+        operations = ForgeOperationsWithManualCast
 
     op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
     backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
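With the new `operations=` parameter, callers can supply their own operations class instead of choosing between the two built-ins; `ControlLora.pre_run` below passes `operations=ControlLoraOps`. A small usage sketch (`TinyBlock` is a hypothetical module):

```python
import torch

from backend.operations import using_forge_operations


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)  # resolves to the patched Linear


# while the context is active, torch.nn.Linear and friends are swapped for the
# selected Forge operations, so any module constructed inside picks them up
with using_forge_operations(parameters_manual_cast=True):
    block = TinyBlock()
```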
diff --git a/ldm_patched/modules/controlnet.py b/backend/patcher/controlnet.py
similarity index 68%
rename from ldm_patched/modules/controlnet.py
rename to backend/patcher/controlnet.py
index df88f49d..b1ef2993 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/backend/patcher/controlnet.py
@@ -1,47 +1,14 @@
-# 1st edit by https://github.com/comfyanonymous/ComfyUI
-# 2nd edit by Forge Official
-
-
 import torch
 import math
-import os
-import ldm_patched.modules.utils
-import ldm_patched.modules.model_management
-import ldm_patched.modules.model_detection
-import ldm_patched.modules.model_patcher
-import ldm_patched.modules.ops
-import ldm_patched.controlnet.cldm
-import ldm_patched.t2ia.adapter
-
-from ldm_patched.modules.ops import main_stream_worker
-
-
-def broadcast_image_to(tensor, target_batch_size, batched_number):
-    current_batch_size = tensor.shape[0]
-    #print(current_batch_size, target_batch_size)
-    if current_batch_size == 1:
-        return tensor
-
-    per_batch = target_batch_size // batched_number
-    tensor = tensor[:per_batch]
-
-    if per_batch > tensor.shape[0]:
-        tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
-
-    current_batch_size = tensor.shape[0]
-    if current_batch_size == target_batch_size:
-        return tensor
-    else:
-        return torch.cat([tensor] * batched_number, dim=0)
-
-
-def get_at(array, index, default=None):
-    return array[index] if 0 <= index < len(array) else default
+from backend.misc import image_resize
+from backend import memory_management, state_dict, utils
+from backend.nn.cnets import cldm, t2i_adapter
+from backend.patcher.base import ModelPatcher
+from backend.operations import using_forge_operations, ForgeOperationsWithManualCast, main_stream_worker, weights_manual_cast
 
 
 def compute_controlnet_weighting(control, cnet):
-
     positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
     negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
     advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
@@ -108,6 +75,28 @@ def compute_controlnet_weighting(control, cnet):
     return control
 
 
+def broadcast_image_to(tensor, target_batch_size, batched_number):
+    current_batch_size = tensor.shape[0]
+    if current_batch_size == 1:
+        return tensor
+
+    per_batch = target_batch_size // batched_number
+    tensor = tensor[:per_batch]
+
+    if per_batch > tensor.shape[0]:
+        tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
+
+    current_batch_size = tensor.shape[0]
+    if current_batch_size == target_batch_size:
+        return tensor
+    else:
+        return torch.cat([tensor] * batched_number, dim=0)
+
+
+def get_at(array, index, default=None):
+    return array[index] if 0 <= index < len(array) else default
+
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -119,7 +108,7 @@ class ControlBase:
         self.transformer_options = {}
 
         if device is None:
-            device = ldm_patched.modules.model_management.get_torch_device()
+            device = memory_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
@@ -164,7 +153,7 @@
         return 0
 
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
-        out = {'input':[], 'middle':[], 'output': []}
+        out = {'input': [], 'middle': [], 'output': []}
 
         if control_input is not None:
             for i in range(len(control_input)):
@@ -214,12 +203,13 @@
                             o[i] += prev_val
         return out
 
+
 class ControlNet(ControlBase):
     def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
-        self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device())
+        self.control_model_wrapped = ModelPatcher(self.control_model, load_device=load_device, offload_device=memory_management.unet_offload_device())
 
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -250,7 +240,7 @@ class ControlNet(ControlBase):
             if self.cond_hint is not None:
                 del self.cond_hint
                 self.cond_hint = None
-            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
+            self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -291,11 +281,10 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()
 
-class ControlLoraOps:
+
+class ControlLoraOps(ForgeOperationsWithManualCast):
     class Linear(torch.nn.Module):
-        def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                     device=None, dtype=None) -> None:
-            factory_kwargs = {'device': device, 'dtype': dtype}
+        def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
             super().__init__()
             self.in_features = in_features
             self.out_features = out_features
@@ -305,7 +294,7 @@ class ControlLoraOps:
             self.bias = None
 
         def forward(self, input):
-            weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
+            weight, bias, signal = weights_manual_cast(self, input)
             with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
@@ -314,18 +303,18 @@ class ControlLoraOps:
 
     class Conv2d(torch.nn.Module):
         def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            padding=0,
-            dilation=1,
-            groups=1,
-            bias=True,
-            padding_mode='zeros',
-            device=None,
-            dtype=None
+                self,
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                bias=True,
+                padding_mode='zeros',
+                device=None,
+                dtype=None
         ):
             super().__init__()
             self.in_channels = in_channels
@@ -344,9 +333,8 @@ class ControlLoraOps:
             self.up = None
             self.down = None
-
         def forward(self, input):
-            weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
+            weight, bias, signal = weights_manual_cast(self, input)
             with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
@@ -362,37 +350,30 @@ class ControlLora(ControlNet):
     def pre_run(self, model, percent_to_timestep_function):
         super().pre_run(model, percent_to_timestep_function)
-        controlnet_config = model.model_config.unet_config.copy()
+        controlnet_config = model.diffusion_model.legacy_config.copy()
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
-        self.manual_cast_dtype = model.manual_cast_dtype
-        dtype = model.get_dtype()
-        if self.manual_cast_dtype is None:
-            class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.disable_weight_init):
-                pass
-        else:
-            class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.manual_cast):
-                pass
-            dtype = self.manual_cast_dtype
+        controlnet_config["dtype"] = dtype = model.storage_dtype
 
-        controlnet_config["operations"] = control_lora_ops
-        controlnet_config["dtype"] = dtype
-        self.control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
-        self.control_model.to(ldm_patched.modules.model_management.get_torch_device())
+        self.manual_cast_dtype = model.computation_dtype
+
+        with using_forge_operations(operations=ControlLoraOps):
+            self.control_model = cldm.ControlNet(**controlnet_config)
+
+        self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
+
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
-        cm = self.control_model.state_dict()
 
         for k in sd:
             weight = sd[k]
             try:
-                ldm_patched.modules.utils.set_attr(self.control_model, k, weight)
+                utils.set_attr(self.control_model, k, weight)
             except:
                 pass
 
         for k in self.control_weights:
             if k not in {"lora_controlnet"}:
-                ldm_patched.modules.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(ldm_patched.modules.model_management.get_torch_device()))
+                utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(memory_management.get_torch_device()))
 
     def copy(self):
         c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@@ -409,117 +390,8 @@ class ControlLora(ControlNet):
         return out
 
     def inference_memory_requirements(self, dtype):
-        return ldm_patched.modules.utils.calculate_parameters(self.control_weights) * ldm_patched.modules.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
+        return utils.calculate_parameters(self.control_weights) * memory_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
 
-
-def load_controlnet(ckpt_path, model=None):
-    controlnet_data = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
-    if "lora_controlnet" in controlnet_data:
-        return ControlLora(controlnet_data)
-
-    controlnet_config = None
-    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        unet_dtype = ldm_patched.modules.model_management.unet_dtype()
-        controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
-        diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
-        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
-        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
-
-        count = 0
-        loop = True
-        while loop:
-            suffix = [".weight", ".bias"]
-            for s in suffix:
-                k_in = "controlnet_down_blocks.{}{}".format(count, s)
-                k_out = "zero_convs.{}.0{}".format(count, s)
-                if k_in not in controlnet_data:
-                    loop = False
-                    break
-                diffusers_keys[k_in] = k_out
-            count += 1
-
-        count = 0
-        loop = True
-        while loop:
-            suffix = [".weight", ".bias"]
-            for s in suffix:
-                if count == 0:
-                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
-                else:
-                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
-                k_out = "input_hint_block.{}{}".format(count * 2, s)
-                if k_in not in controlnet_data:
-                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
-                    loop = False
-                diffusers_keys[k_in] = k_out
-            count += 1
-
-        new_sd = {}
-        for k in diffusers_keys:
-            if k in controlnet_data:
-                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
-
-        leftover_keys = controlnet_data.keys()
-        if len(leftover_keys) > 0:
-            print("leftover keys:", leftover_keys)
-        controlnet_data = new_sd
-
-    pth_key = 'control_model.zero_convs.0.0.weight'
-    pth = False
-    key = 'zero_convs.0.0.weight'
-    if pth_key in controlnet_data:
-        pth = True
-        key = pth_key
-        prefix = "control_model."
-    elif key in controlnet_data:
-        prefix = ""
-    else:
-        net = load_t2i_adapter(controlnet_data)
-        if net is None:
-            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
-        return net
-
-    if controlnet_config is None:
-        unet_dtype = ldm_patched.modules.model_management.unet_dtype()
-        controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
-    load_device = ldm_patched.modules.model_management.get_torch_device()
-    manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
-    controlnet_config.pop("out_channels")
-    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
-    control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
-
-    if pth:
-        if 'difference' in controlnet_data:
-            if model is not None:
-                ldm_patched.modules.model_management.load_models_gpu([model])
-                model_sd = model.model_state_dict()
-                for x in controlnet_data:
-                    c_m = "control_model."
-                    if x.startswith(c_m):
-                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
-                        if sd_key in model_sd:
-                            cd = controlnet_data[x]
-                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
-            else:
-                print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
It will very likely not work.") - - class WeightsLoader(torch.nn.Module): - pass - w = WeightsLoader() - w.control_model = control_model - missing, unexpected = w.load_state_dict(controlnet_data, strict=False) - else: - missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) - - global_average_pooling = False - filename = os.path.splitext(ckpt_path)[0] - if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - - control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) - return control class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, device=None): @@ -557,7 +429,7 @@ class T2IAdapter(ControlBase): self.control_input = None self.cond_hint = None width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float() + self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, width, height, 'nearest-exact', "center").float() if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) if x_noisy.shape[0] != self.cond_hint.shape[0]: @@ -591,22 +463,23 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c + def load_t2i_adapter(t2i_data): if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] - if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format + if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: # diffusers format prefix_replace = {} for i in range(4): for j in range(2): prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j) prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2) prefix_replace["adapter."] = "" - t2i_data = ldm_patched.modules.utils.state_dict_prefix_replace(t2i_data, prefix_replace) + t2i_data = state_dict.state_dict_prefix_replace(t2i_data, prefix_replace) keys = t2i_data.keys() if "body.0.in_conv.weight" in keys: cin = t2i_data['body.0.in_conv.weight'].shape[1] - model_ad = ldm_patched.t2ia.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) + model_ad = t2i_adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) elif 'conv_in.weight' in keys: cin = t2i_data['conv_in.weight'].shape[1] channel = t2i_data['conv_in.weight'].shape[0] @@ -618,9 +491,10 @@ def load_t2i_adapter(t2i_data): xl = False if cin == 256 or cin == 768: xl = True - model_ad = ldm_patched.t2ia.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + model_ad = t2i_adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) else: return None + missing, unexpected = model_ad.load_state_dict(t2i_data) if len(missing) > 0: print("t2i missing", missing) diff --git a/backend/utils.py b/backend/utils.py index 57b9276b..bd3d751d 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -28,3 +28,11 @@ def get_attr(obj, attr): for name in attrs: obj = getattr(obj, name) return obj + + +def calculate_parameters(sd, prefix=""): + params = 0 + for k in sd.keys(): + if k.startswith(prefix): + params += sd[k].nelement() + return params diff --git 
diff --git a/extensions-builtin/sd_forge_multidiffusion/lib_multidiffusion/tiled_diffusion.py b/extensions-builtin/sd_forge_multidiffusion/lib_multidiffusion/tiled_diffusion.py
index 39890404..5ad5b860 100644
--- a/extensions-builtin/sd_forge_multidiffusion/lib_multidiffusion/tiled_diffusion.py
+++ b/extensions-builtin/sd_forge_multidiffusion/lib_multidiffusion/tiled_diffusion.py
@@ -14,7 +14,7 @@ from ldm_patched.modules.model_base import BaseModel
 from typing import List, Union, Tuple, Dict
 from ldm_patched.contrib.external import ImageScale
 import ldm_patched.modules.utils
-from ldm_patched.modules.controlnet import ControlNet, T2IAdapter
+from backend.patcher.controlnet import ControlNet, T2IAdapter
 
 opt_C = 4
 opt_f = 8
diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py
index 86f533a6..efb29726 100644
--- a/ldm_patched/contrib/external.py
+++ b/ldm_patched/contrib/external.py
@@ -26,7 +26,7 @@ import ldm_patched.modules.samplers
 import ldm_patched.modules.sample
 import ldm_patched.modules.sd
 import ldm_patched.modules.utils
-import ldm_patched.modules.controlnet
+# import ldm_patched.modules.controlnet
 import ldm_patched.modules.clip_vision
diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index 38377e83..f2d0b5eb 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -22,7 +22,7 @@ from . import sdxl_clip
 import ldm_patched.modules.model_patcher
 import ldm_patched.modules.lora
-import ldm_patched.t2ia.adapter
+# import ldm_patched.t2ia.adapter
 import ldm_patched.modules.supported_models_base
 import ldm_patched.taesd.taesd
diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py
index 1490259e..49b04c19 100644
--- a/modules_forge/supported_controlnet.py
+++ b/modules_forge/supported_controlnet.py
@@ -1,9 +1,10 @@
 import os
 import torch
 import ldm_patched.modules.utils
-import ldm_patched.controlnet
-from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter
+from backend.operations import using_forge_operations
+from backend.nn.cnets import cldm
+from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
 from modules_forge.controlnet import apply_controlnet_advanced
 from modules_forge.shared import add_supported_control_model
@@ -43,8 +44,7 @@ class ControlNetPatcher(ControlModelPatcher):
         controlnet_config = None
         if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:  # diffusers format
             unet_dtype = ldm_patched.modules.model_management.unet_dtype()
-            controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data,
-                                                                                                    unet_dtype)
+            controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
             diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
             diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
             diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -105,15 +105,16 @@ class ControlNetPatcher(ControlModelPatcher):
 
         if controlnet_config is None:
             unet_dtype = ldm_patched.modules.model_management.unet_dtype()
-            controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix,
-                                                                                           unet_dtype, True).unet_config
+            controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+
         load_device = ldm_patched.modules.model_management.get_torch_device()
         manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
-        if manual_cast_dtype is not None:
-            controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
+
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
-        control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
+
+        with using_forge_operations(parameters_manual_cast=manual_cast_dtype is not None):
+            control_model = cldm.ControlNet(**controlnet_config)
 
         if pth:
             if 'difference' in controlnet_data:
@@ -136,8 +137,7 @@ class ControlNetPatcher(ControlModelPatcher):
             # TODO: smarter way of enabling global_average_pooling
             global_average_pooling = True
 
-        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device,
-                             manual_cast_dtype=manual_cast_dtype)
+        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
         return ControlNetPatcher(control)
 
     def __init__(self, model_patcher):
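Finally, a behavioral sketch of `broadcast_image_to`, which now lives in `backend/patcher/controlnet.py` and stretches a hint batch across cond/uncond groups (the batch sizes are illustrative):

```python
import torch

from backend.patcher.controlnet import broadcast_image_to

hints = torch.rand(2, 3, 64, 64)  # two control hints

# target batch of 6 sampled as 3 groups of 2: the hints are tiled once per group
out = broadcast_image_to(hints, target_batch_size=6, batched_number=3)
print(out.shape)  # torch.Size([6, 3, 64, 64])
```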