control rework

This commit is contained in:
layerdiffusion
2024-08-02 21:29:51 -07:00
parent 9a449a1d98
commit e722991752
11 changed files with 314 additions and 324 deletions

View File

@@ -0,0 +1,113 @@
import torch
import numpy as np
from PIL import Image
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
# norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
# normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
# zero when norms are zero
b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
# slerp
dot = (b1_normalized * b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
# technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
# edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
coords_2[:, :, :, -1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n, c, h, w = samples.shape
h_new, w_new = (height, width)
# linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
ratios = ratios.movedim(1, -1).reshape((-1, 1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
# linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
ratios = ratios.movedim(1, -1).reshape((-1, 1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)
def adaptive_resize(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:, :, y:old_height - y, x:old_width - x]
else:
s = samples
if upscale_method == "bislerp":
return bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

View File

@@ -1,60 +1,42 @@
#taken from: https://github.com/lllyasviel/ControlNet
#and modified
import torch
import torch as th
import torch.nn as nn
from ldm_patched.ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)
from backend.nn.unet import timestep_embedding, exists, conv_nd, SpatialTransformer, TimestepEmbedSequential, ResBlock, Downsample
from ldm_patched.ldm.modules.attention import SpatialTransformer
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ldm_patched.ldm.util import exists
import ldm_patched.modules.ops
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
pass
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
dtype=torch.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=ldm_patched.modules.ops.disable_weight_init,
**kwargs,
self,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
dtype=torch.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False,
transformer_depth=1,
context_dim=None,
n_embed=None,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -77,7 +59,6 @@ class ControlNet(nn.Module):
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
@@ -111,9 +92,9 @@ class ControlNet(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
if self.num_classes is not None:
@@ -126,9 +107,9 @@ class ControlNet(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
)
else:
@@ -137,28 +118,28 @@ class ControlNet(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
nn.Conv2d(in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
self._feature_size = model_channels
@@ -178,7 +159,6 @@ class ControlNet(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations,
)
]
ch = mult * model_channels
@@ -189,9 +169,7 @@ class ControlNet(nn.Module):
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
@@ -202,11 +180,11 @@ class ControlNet(nn.Module):
SpatialTransformer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@@ -224,17 +202,16 @@ class ControlNet(nn.Module):
down=True,
dtype=self.dtype,
device=device,
operations=operations
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch
@@ -243,9 +220,7 @@ class ControlNet(nn.Module):
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [
ResBlock(
ch,
@@ -256,31 +231,30 @@ class ControlNet(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
mid_block += [
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self.middle_block_out = self.make_zero_conv(ch, dtype=self.dtype, device=device)
self._feature_size += ch
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def make_zero_conv(self, channels, dtype=None, device=None):
return TimestepEmbedSequential(conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
@@ -309,4 +283,3 @@ class ControlNet(nn.Module):
outs.append(self.middle_block_out(h, emb, context))
return outs

View File

@@ -1,6 +1,6 @@
#taken from https://github.com/TencentARC/T2I-Adapter
import torch
import torch.nn as nn
from collections import OrderedDict
@@ -274,9 +274,9 @@ class Adapter_light(nn.Module):
for i in range(len(channels)):
if i == 0:
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
self.body.append(extractor(in_c=cin, inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=False))
else:
self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
self.body.append(extractor(in_c=channels[i - 1], inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=True))
self.body = nn.ModuleList(self.body)
def forward(self, x):

View File

@@ -655,12 +655,32 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
device = unet_initial_device
self.legacy_config = dict(
num_res_blocks=num_res_blocks,
channel_mult=channel_mult,
transformer_depth=transformer_depth,
transformer_depth_output=transformer_depth_output,
transformer_depth_middle=transformer_depth_middle,
in_channels=in_channels,
out_channels=out_channels,
model_channels=model_channels,
num_res_blocks=num_res_blocks,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
num_classes=num_classes,
dtype=dtype,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
use_spatial_transformer=use_spatial_transformer,
transformer_depth=transformer_depth,
context_dim=context_dim,
disable_self_attentions=disable_self_attentions,
num_attention_blocks=num_attention_blocks,
disable_middle_self_attn=disable_middle_self_attn,
use_linear_in_transformer=use_linear_in_transformer,
adm_in_channels=adm_in_channels,
transformer_depth_middle=transformer_depth_middle,
transformer_depth_output=transformer_depth_output,
device=device,
)
if context_dim is not None:

View File

@@ -150,11 +150,13 @@ class ForgeOperationsWithManualCast(ForgeOperations):
@contextlib.contextmanager
def using_forge_operations(parameters_manual_cast=False):
operations = ForgeOperations
def using_forge_operations(parameters_manual_cast=False, operations=None):
if parameters_manual_cast:
operations = ForgeOperationsWithManualCast
if operations is None:
operations = ForgeOperations
if parameters_manual_cast:
operations = ForgeOperationsWithManualCast
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

View File

@@ -1,47 +1,14 @@
# 1st edit by https://github.com/comfyanonymous/ComfyUI
# 2nd edit by Forge Official
import torch
import math
import os
import ldm_patched.modules.utils
import ldm_patched.modules.model_management
import ldm_patched.modules.model_detection
import ldm_patched.modules.model_patcher
import ldm_patched.modules.ops
import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
from ldm_patched.modules.ops import main_stream_worker
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
per_batch = target_batch_size // batched_number
tensor = tensor[:per_batch]
if per_batch > tensor.shape[0]:
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
current_batch_size = tensor.shape[0]
if current_batch_size == target_batch_size:
return tensor
else:
return torch.cat([tensor] * batched_number, dim=0)
def get_at(array, index, default=None):
return array[index] if 0 <= index < len(array) else default
from backend.misc import image_resize
from backend import memory_management, state_dict, utils
from backend.nn.cnets import cldm, t2i_adapter
from backend.patcher.base import ModelPatcher
from backend.operations import using_forge_operations, ForgeOperationsWithManualCast, main_stream_worker, weights_manual_cast
def compute_controlnet_weighting(control, cnet):
positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
@@ -108,6 +75,28 @@ def compute_controlnet_weighting(control, cnet):
return control
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
if current_batch_size == 1:
return tensor
per_batch = target_batch_size // batched_number
tensor = tensor[:per_batch]
if per_batch > tensor.shape[0]:
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
current_batch_size = tensor.shape[0]
if current_batch_size == target_batch_size:
return tensor
else:
return torch.cat([tensor] * batched_number, dim=0)
def get_at(array, index, default=None):
return array[index] if 0 <= index < len(array) else default
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
@@ -119,7 +108,7 @@ class ControlBase:
self.transformer_options = {}
if device is None:
device = ldm_patched.modules.model_management.get_torch_device()
device = memory_management.get_torch_device()
self.device = device
self.previous_controlnet = None
@@ -164,7 +153,7 @@ class ControlBase:
return 0
def control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}
out = {'input': [], 'middle': [], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
@@ -214,12 +203,13 @@ class ControlBase:
o[i] += prev_val
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device())
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=load_device, offload_device=memory_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
@@ -250,7 +240,7 @@ class ControlNet(ControlBase):
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -291,11 +281,10 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps:
class ControlLoraOps(ForgeOperationsWithManualCast):
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -305,7 +294,7 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
weight, bias, signal = weights_manual_cast(self, input)
with main_stream_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
@@ -314,18 +303,18 @@ class ControlLoraOps:
class Conv2d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
):
super().__init__()
self.in_channels = in_channels
@@ -344,9 +333,8 @@ class ControlLoraOps:
self.up = None
self.down = None
def forward(self, input):
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
weight, bias, signal = weights_manual_cast(self, input)
with main_stream_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
@@ -362,37 +350,30 @@ class ControlLora(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
controlnet_config = model.model_config.unet_config.copy()
controlnet_config = model.diffusion_model.legacy_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
self.manual_cast_dtype = model.manual_cast_dtype
dtype = model.get_dtype()
if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["dtype"] = dtype = model.storage_dtype
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
self.control_model.to(ldm_patched.modules.model_management.get_torch_device())
self.manual_cast_dtype = model.computation_dtype
with using_forge_operations(operations=ControlLoraOps):
self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
try:
ldm_patched.modules.utils.set_attr(self.control_model, k, weight)
utils.set_attr(self.control_model, k, weight)
except:
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
ldm_patched.modules.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(ldm_patched.modules.model_management.get_torch_device()))
utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(memory_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@@ -409,117 +390,8 @@ class ControlLora(ControlNet):
return out
def inference_memory_requirements(self, dtype):
return ldm_patched.modules.utils.calculate_parameters(self.control_weights) * ldm_patched.modules.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
return utils.calculate_parameters(self.control_weights) * memory_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet(ckpt_path, model=None):
controlnet_data = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
k_in = "controlnet_down_blocks.{}{}".format(count, s)
k_out = "zero_convs.{}.0{}".format(count, s)
if k_in not in controlnet_data:
loop = False
break
diffusers_keys[k_in] = k_out
count += 1
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
if count == 0:
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
else:
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
k_out = "input_hint_block.{}{}".format(count * 2, s)
if k_in not in controlnet_data:
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
loop = False
diffusers_keys[k_in] = k_out
count += 1
new_sd = {}
for k in diffusers_keys:
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys)
controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
key = 'zero_convs.0.0.weight'
if pth_key in controlnet_data:
pth = True
key = pth_key
prefix = "control_model."
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
if controlnet_config is None:
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
load_device = ldm_patched.modules.model_management.get_torch_device()
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
if pth:
if 'difference' in controlnet_data:
if model is not None:
ldm_patched.modules.model_management.load_models_gpu([model])
model_sd = model.model_state_dict()
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None):
@@ -557,7 +429,7 @@ class T2IAdapter(ControlBase):
self.control_input = None
self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
@@ -591,22 +463,23 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
if 'adapter' in t2i_data:
t2i_data = t2i_data['adapter']
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: # diffusers format
prefix_replace = {}
for i in range(4):
for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = ldm_patched.modules.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
t2i_data = state_dict.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = ldm_patched.t2ia.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
model_ad = t2i_adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
channel = t2i_data['conv_in.weight'].shape[0]
@@ -618,9 +491,10 @@ def load_t2i_adapter(t2i_data):
xl = False
if cin == 256 or cin == 768:
xl = True
model_ad = ldm_patched.t2ia.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
model_ad = t2i_adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
print("t2i missing", missing)

View File

@@ -28,3 +28,11 @@ def get_attr(obj, attr):
for name in attrs:
obj = getattr(obj, name)
return obj
def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
params += sd[k].nelement()
return params

View File

@@ -14,7 +14,7 @@ from ldm_patched.modules.model_base import BaseModel
from typing import List, Union, Tuple, Dict
from ldm_patched.contrib.external import ImageScale
import ldm_patched.modules.utils
from ldm_patched.modules.controlnet import ControlNet, T2IAdapter
from backend.patcher.controlnet import ControlNet, T2IAdapter
opt_C = 4
opt_f = 8

View File

@@ -26,7 +26,7 @@ import ldm_patched.modules.samplers
import ldm_patched.modules.sample
import ldm_patched.modules.sd
import ldm_patched.modules.utils
import ldm_patched.modules.controlnet
# import ldm_patched.modules.controlnet
import ldm_patched.modules.clip_vision

View File

@@ -22,7 +22,7 @@ from . import sdxl_clip
import ldm_patched.modules.model_patcher
import ldm_patched.modules.lora
import ldm_patched.t2ia.adapter
# import ldm_patched.t2ia.adapter
import ldm_patched.modules.supported_models_base
import ldm_patched.taesd.taesd

View File

@@ -1,9 +1,10 @@
import os
import torch
import ldm_patched.modules.utils
import ldm_patched.controlnet
from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter
from backend.operations import using_forge_operations
from backend.nn.cnets import cldm
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
from modules_forge.controlnet import apply_controlnet_advanced
from modules_forge.shared import add_supported_control_model
@@ -43,8 +44,7 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data,
unet_dtype)
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -105,15 +105,16 @@ class ControlNetPatcher(ControlModelPatcher):
if controlnet_config is None:
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix,
unet_dtype, True).unet_config
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
load_device = ldm_patched.modules.model_management.get_torch_device()
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
with using_forge_operations(parameters_manual_cast=manual_cast_dtype is not None):
control_model = cldm.ControlNet(**controlnet_config)
if pth:
if 'difference' in controlnet_data:
@@ -136,8 +137,7 @@ class ControlNetPatcher(ControlModelPatcher):
# TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device,
manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return ControlNetPatcher(control)
def __init__(self, model_patcher):