mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-04 04:29:50 +00:00
control rework
This commit is contained in:
113
backend/misc/image_resize.py
Normal file
113
backend/misc/image_resize.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
|
||||
c = b1.shape[-1]
|
||||
|
||||
# norms
|
||||
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||
|
||||
# normalize
|
||||
b1_normalized = b1 / b1_norms
|
||||
b2_normalized = b2 / b2_norms
|
||||
|
||||
# zero when norms are zero
|
||||
b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
|
||||
b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
|
||||
|
||||
# slerp
|
||||
dot = (b1_normalized * b2_normalized).sum(1)
|
||||
omega = torch.acos(dot)
|
||||
so = torch.sin(omega)
|
||||
|
||||
# technically not mathematically correct, but more pleasing?
|
||||
res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
|
||||
res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
|
||||
|
||||
# edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
|
||||
coords_2[:, :, :, -1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
coords_2 = coords_2.to(torch.int64)
|
||||
return ratios, coords_1, coords_2
|
||||
|
||||
orig_dtype = samples.dtype
|
||||
samples = samples.float()
|
||||
n, c, h, w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
# linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
coords_2 = coords_2.expand((n, c, h, -1))
|
||||
ratios = ratios.expand((n, 1, h, -1))
|
||||
|
||||
pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
|
||||
pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1, 1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||
|
||||
# linear h
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
||||
coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
|
||||
coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
|
||||
ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
|
||||
|
||||
pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
|
||||
pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1, 1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||
return result.to(orig_dtype)
|
||||
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
|
||||
def adaptive_resize(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:, :, y:old_height - y, x:old_width - x]
|
||||
else:
|
||||
s = samples
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
return bislerp(s, width, height)
|
||||
elif upscale_method == "lanczos":
|
||||
return lanczos(s, width, height)
|
||||
else:
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
@@ -1,60 +1,42 @@
|
||||
#taken from: https://github.com/lllyasviel/ControlNet
|
||||
#and modified
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm_patched.ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
from backend.nn.unet import timestep_embedding, exists, conv_nd, SpatialTransformer, TimestepEmbedSequential, ResBlock, Downsample
|
||||
|
||||
from ldm_patched.ldm.modules.attention import SpatialTransformer
|
||||
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||
from ldm_patched.ldm.util import exists
|
||||
import ldm_patched.modules.ops
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
pass
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
dtype=torch.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=ldm_patched.modules.ops.disable_weight_init,
|
||||
**kwargs,
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
dtype=torch.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False,
|
||||
transformer_depth=1,
|
||||
context_dim=None,
|
||||
n_embed=None,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
@@ -77,7 +59,6 @@ class ControlNet(nn.Module):
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
|
||||
@@ -111,9 +92,9 @@ class ControlNet(nn.Module):
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
@@ -126,9 +107,9 @@ class ControlNet(nn.Module):
|
||||
assert adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -137,28 +118,28 @@ class ControlNet(nn.Module):
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
nn.Conv2d(in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, dtype=self.dtype, device=device)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
@@ -178,7 +159,6 @@ class ControlNet(nn.Module):
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
@@ -189,9 +169,7 @@ class ControlNet(nn.Module):
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
@@ -202,11 +180,11 @@ class ControlNet(nn.Module):
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
@@ -224,17 +202,16 @@ class ControlNet(nn.Module):
|
||||
down=True,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, dtype=self.dtype, device=device))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
@@ -243,9 +220,7 @@ class ControlNet(nn.Module):
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -256,31 +231,30 @@ class ControlNet(nn.Module):
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
mid_block += [
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||
self.middle_block_out = self.make_zero_conv(ch, dtype=self.dtype, device=device)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
||||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||
def make_zero_conv(self, channels, dtype=None, device=None):
|
||||
return TimestepEmbedSequential(conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
@@ -309,4 +283,3 @@ class ControlNet(nn.Module):
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#taken from https://github.com/TencentARC/T2I-Adapter
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
@@ -274,9 +274,9 @@ class Adapter_light(nn.Module):
|
||||
|
||||
for i in range(len(channels)):
|
||||
if i == 0:
|
||||
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
||||
self.body.append(extractor(in_c=cin, inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
||||
else:
|
||||
self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
|
||||
self.body.append(extractor(in_c=channels[i - 1], inter_c=channels[i] // 4, out_c=channels[i], nums_rb=nums_rb, down=True))
|
||||
self.body = nn.ModuleList(self.body)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -655,12 +655,32 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
|
||||
device = unet_initial_device
|
||||
|
||||
self.legacy_config = dict(
|
||||
num_res_blocks=num_res_blocks,
|
||||
channel_mult=channel_mult,
|
||||
transformer_depth=transformer_depth,
|
||||
transformer_depth_output=transformer_depth_output,
|
||||
transformer_depth_middle=transformer_depth_middle,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
model_channels=model_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
channel_mult=channel_mult,
|
||||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
num_classes=num_classes,
|
||||
dtype=dtype,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
use_spatial_transformer=use_spatial_transformer,
|
||||
transformer_depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attentions=disable_self_attentions,
|
||||
num_attention_blocks=num_attention_blocks,
|
||||
disable_middle_self_attn=disable_middle_self_attn,
|
||||
use_linear_in_transformer=use_linear_in_transformer,
|
||||
adm_in_channels=adm_in_channels,
|
||||
transformer_depth_middle=transformer_depth_middle,
|
||||
transformer_depth_output=transformer_depth_output,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if context_dim is not None:
|
||||
|
||||
@@ -150,11 +150,13 @@ class ForgeOperationsWithManualCast(ForgeOperations):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_forge_operations(parameters_manual_cast=False):
|
||||
operations = ForgeOperations
|
||||
def using_forge_operations(parameters_manual_cast=False, operations=None):
|
||||
|
||||
if parameters_manual_cast:
|
||||
operations = ForgeOperationsWithManualCast
|
||||
if operations is None:
|
||||
operations = ForgeOperations
|
||||
|
||||
if parameters_manual_cast:
|
||||
operations = ForgeOperationsWithManualCast
|
||||
|
||||
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
|
||||
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
|
||||
|
||||
@@ -1,47 +1,14 @@
|
||||
# 1st edit by https://github.com/comfyanonymous/ComfyUI
|
||||
# 2nd edit by Forge Official
|
||||
|
||||
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.model_management
|
||||
import ldm_patched.modules.model_detection
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.ops
|
||||
|
||||
import ldm_patched.controlnet.cldm
|
||||
import ldm_patched.t2ia.adapter
|
||||
|
||||
from ldm_patched.modules.ops import main_stream_worker
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
per_batch = target_batch_size // batched_number
|
||||
tensor = tensor[:per_batch]
|
||||
|
||||
if per_batch > tensor.shape[0]:
|
||||
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
|
||||
|
||||
current_batch_size = tensor.shape[0]
|
||||
if current_batch_size == target_batch_size:
|
||||
return tensor
|
||||
else:
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
|
||||
def get_at(array, index, default=None):
|
||||
return array[index] if 0 <= index < len(array) else default
|
||||
from backend.misc import image_resize
|
||||
from backend import memory_management, state_dict, utils
|
||||
from backend.nn.cnets import cldm, t2i_adapter
|
||||
from backend.patcher.base import ModelPatcher
|
||||
from backend.operations import using_forge_operations, ForgeOperationsWithManualCast, main_stream_worker, weights_manual_cast
|
||||
|
||||
|
||||
def compute_controlnet_weighting(control, cnet):
|
||||
|
||||
positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
|
||||
negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
|
||||
advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
|
||||
@@ -108,6 +75,28 @@ def compute_controlnet_weighting(control, cnet):
|
||||
return control
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
per_batch = target_batch_size // batched_number
|
||||
tensor = tensor[:per_batch]
|
||||
|
||||
if per_batch > tensor.shape[0]:
|
||||
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
|
||||
|
||||
current_batch_size = tensor.shape[0]
|
||||
if current_batch_size == target_batch_size:
|
||||
return tensor
|
||||
else:
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
|
||||
def get_at(array, index, default=None):
|
||||
return array[index] if 0 <= index < len(array) else default
|
||||
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
self.cond_hint_original = None
|
||||
@@ -119,7 +108,7 @@ class ControlBase:
|
||||
self.transformer_options = {}
|
||||
|
||||
if device is None:
|
||||
device = ldm_patched.modules.model_management.get_torch_device()
|
||||
device = memory_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
|
||||
@@ -164,7 +153,7 @@ class ControlBase:
|
||||
return 0
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
out = {'input': [], 'middle': [], 'output': []}
|
||||
|
||||
if control_input is not None:
|
||||
for i in range(len(control_input)):
|
||||
@@ -214,12 +203,13 @@ class ControlBase:
|
||||
o[i] += prev_val
|
||||
return out
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=load_device, offload_device=memory_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
@@ -250,7 +240,7 @@ class ControlNet(ControlBase):
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
|
||||
self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@@ -291,11 +281,10 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
|
||||
class ControlLoraOps(ForgeOperationsWithManualCast):
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
@@ -305,7 +294,7 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
weight, bias, signal = weights_manual_cast(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
@@ -314,18 +303,18 @@ class ControlLoraOps:
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -344,9 +333,8 @@ class ControlLoraOps:
|
||||
self.up = None
|
||||
self.down = None
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
weight, bias, signal = weights_manual_cast(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
@@ -362,37 +350,30 @@ class ControlLora(ControlNet):
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
controlnet_config = model.model_config.unet_config.copy()
|
||||
controlnet_config = model.diffusion_model.legacy_config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
self.manual_cast_dtype = model.manual_cast_dtype
|
||||
dtype = model.get_dtype()
|
||||
if self.manual_cast_dtype is None:
|
||||
class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.disable_weight_init):
|
||||
pass
|
||||
else:
|
||||
class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.manual_cast):
|
||||
pass
|
||||
dtype = self.manual_cast_dtype
|
||||
controlnet_config["dtype"] = dtype = model.storage_dtype
|
||||
|
||||
controlnet_config["operations"] = control_lora_ops
|
||||
controlnet_config["dtype"] = dtype
|
||||
self.control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
self.control_model.to(ldm_patched.modules.model_management.get_torch_device())
|
||||
self.manual_cast_dtype = model.computation_dtype
|
||||
|
||||
with using_forge_operations(operations=ControlLoraOps):
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
try:
|
||||
ldm_patched.modules.utils.set_attr(self.control_model, k, weight)
|
||||
utils.set_attr(self.control_model, k, weight)
|
||||
except:
|
||||
pass
|
||||
|
||||
for k in self.control_weights:
|
||||
if k not in {"lora_controlnet"}:
|
||||
ldm_patched.modules.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(ldm_patched.modules.model_management.get_torch_device()))
|
||||
utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(memory_management.get_torch_device()))
|
||||
|
||||
def copy(self):
|
||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||
@@ -409,117 +390,8 @@ class ControlLora(ControlNet):
|
||||
return out
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return ldm_patched.modules.utils.calculate_parameters(self.control_weights) * ldm_patched.modules.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
return utils.calculate_parameters(self.control_weights) * memory_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
|
||||
count = 0
|
||||
loop = True
|
||||
while loop:
|
||||
suffix = [".weight", ".bias"]
|
||||
for s in suffix:
|
||||
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
||||
k_out = "zero_convs.{}.0{}".format(count, s)
|
||||
if k_in not in controlnet_data:
|
||||
loop = False
|
||||
break
|
||||
diffusers_keys[k_in] = k_out
|
||||
count += 1
|
||||
|
||||
count = 0
|
||||
loop = True
|
||||
while loop:
|
||||
suffix = [".weight", ".bias"]
|
||||
for s in suffix:
|
||||
if count == 0:
|
||||
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
||||
else:
|
||||
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
||||
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
||||
if k_in not in controlnet_data:
|
||||
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
||||
loop = False
|
||||
diffusers_keys[k_in] = k_out
|
||||
count += 1
|
||||
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in controlnet_data:
|
||||
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||
|
||||
leftover_keys = controlnet_data.keys()
|
||||
if len(leftover_keys) > 0:
|
||||
print("leftover keys:", leftover_keys)
|
||||
controlnet_data = new_sd
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
key = 'zero_convs.0.0.weight'
|
||||
if pth_key in controlnet_data:
|
||||
pth = True
|
||||
key = pth_key
|
||||
prefix = "control_model."
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
if net is None:
|
||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
load_device = ldm_patched.modules.model_management.get_torch_device()
|
||||
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
if model is not None:
|
||||
ldm_patched.modules.model_management.load_models_gpu([model])
|
||||
model_sd = model.model_state_dict()
|
||||
for x in controlnet_data:
|
||||
c_m = "control_model."
|
||||
if x.startswith(c_m):
|
||||
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
||||
if sd_key in model_sd:
|
||||
cd = controlnet_data[x]
|
||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||
else:
|
||||
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, device=None):
|
||||
@@ -557,7 +429,7 @@ class T2IAdapter(ControlBase):
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
|
||||
self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
|
||||
self.cond_hint = image_resize.adaptive_resize(self.cond_hint_original, width, height, 'nearest-exact', "center").float()
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
@@ -591,22 +463,23 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
if 'adapter' in t2i_data:
|
||||
t2i_data = t2i_data['adapter']
|
||||
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
|
||||
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: # diffusers format
|
||||
prefix_replace = {}
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
||||
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter."] = ""
|
||||
t2i_data = ldm_patched.modules.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||
t2i_data = state_dict.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||
keys = t2i_data.keys()
|
||||
|
||||
if "body.0.in_conv.weight" in keys:
|
||||
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
||||
model_ad = ldm_patched.t2ia.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
||||
model_ad = t2i_adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
||||
elif 'conv_in.weight' in keys:
|
||||
cin = t2i_data['conv_in.weight'].shape[1]
|
||||
channel = t2i_data['conv_in.weight'].shape[0]
|
||||
@@ -618,9 +491,10 @@ def load_t2i_adapter(t2i_data):
|
||||
xl = False
|
||||
if cin == 256 or cin == 768:
|
||||
xl = True
|
||||
model_ad = ldm_patched.t2ia.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||
model_ad = t2i_adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||
else:
|
||||
return None
|
||||
|
||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||
if len(missing) > 0:
|
||||
print("t2i missing", missing)
|
||||
@@ -28,3 +28,11 @@ def get_attr(obj, attr):
|
||||
for name in attrs:
|
||||
obj = getattr(obj, name)
|
||||
return obj
|
||||
|
||||
|
||||
def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
@@ -14,7 +14,7 @@ from ldm_patched.modules.model_base import BaseModel
|
||||
from typing import List, Union, Tuple, Dict
|
||||
from ldm_patched.contrib.external import ImageScale
|
||||
import ldm_patched.modules.utils
|
||||
from ldm_patched.modules.controlnet import ControlNet, T2IAdapter
|
||||
from backend.patcher.controlnet import ControlNet, T2IAdapter
|
||||
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
|
||||
@@ -26,7 +26,7 @@ import ldm_patched.modules.samplers
|
||||
import ldm_patched.modules.sample
|
||||
import ldm_patched.modules.sd
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.controlnet
|
||||
# import ldm_patched.modules.controlnet
|
||||
|
||||
import ldm_patched.modules.clip_vision
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from . import sdxl_clip
|
||||
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.lora
|
||||
import ldm_patched.t2ia.adapter
|
||||
# import ldm_patched.t2ia.adapter
|
||||
import ldm_patched.modules.supported_models_base
|
||||
import ldm_patched.taesd.taesd
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.controlnet
|
||||
|
||||
from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.cnets import cldm
|
||||
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
|
||||
from modules_forge.controlnet import apply_controlnet_advanced
|
||||
from modules_forge.shared import add_supported_control_model
|
||||
|
||||
@@ -43,8 +44,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data,
|
||||
unet_dtype)
|
||||
controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@@ -105,15 +105,16 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = ldm_patched.modules.model_management.unet_dtype()
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix,
|
||||
unet_dtype, True).unet_config
|
||||
controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
|
||||
load_device = ldm_patched.modules.model_management.get_torch_device()
|
||||
manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
|
||||
with using_forge_operations(parameters_manual_cast=manual_cast_dtype is not None):
|
||||
control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
@@ -136,8 +137,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
# TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return ControlNetPatcher(control)
|
||||
|
||||
def __init__(self, model_patcher):
|
||||
|
||||
Reference in New Issue
Block a user