Files
stable-diffusion-webui-forge/modules_forge/patch_basic.py
2024-01-27 23:11:56 -08:00

94 lines
2.9 KiB
Python

import torch
from ldm_patched.modules.controlnet import ControlBase
from ldm_patched.modules.model_patcher import ModelPatcher
og_model_patcher_init = ModelPatcher.__init__
og_model_patcher_clone = ModelPatcher.clone
def patched_model_patcher_init(self, *args, **kwargs):
h = og_model_patcher_init(self, *args, **kwargs)
self.controlnet_linked_list = None
return h
def patched_model_patcher_clone(self):
cloned = og_model_patcher_clone(self)
cloned.controlnet_linked_list = self.controlnet_linked_list
return cloned
def model_patcher_add_patched_controlnet(self, cnet):
cnet.set_previous_controlnet(self.controlnet_linked_list)
self.controlnet_linked_list = cnet
return
def model_patcher_list_controlnets(self):
results = []
pointer = self.controlnet_linked_list
while pointer is not None:
results.append(pointer)
pointer = pointer.previous_controlnet
return results
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input': [], 'middle': [], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
for i in range(len(control_prev[x])):
prev_val = control_prev[x][i]
if i >= len(o):
o.append(prev_val)
elif prev_val is not None:
if o[i] is None:
o[i] = prev_val
else:
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out
def patch_all_basics():
ModelPatcher.__init__ = patched_model_patcher_init
ModelPatcher.clone = patched_model_patcher_clone
ModelPatcher.add_patched_controlnet = model_patcher_add_patched_controlnet
ModelPatcher.list_controlnets = model_patcher_list_controlnets
ControlBase.control_merge = patched_control_merge
return