From 33b3fe9701d73ae59ca8280eed1b1d2c81d3f895 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 23:11:56 -0800 Subject: [PATCH] Update patch_basic.py --- modules_forge/patch_basic.py | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index 3f36a674..ce24ff71 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -1,3 +1,6 @@ +import torch + +from ldm_patched.modules.controlnet import ControlBase from ldm_patched.modules.model_patcher import ModelPatcher @@ -32,9 +35,59 @@ def model_patcher_list_controlnets(self): return results +def patched_control_merge(self, control_input, control_output, control_prev, output_dtype): + out = {'input': [], 'middle': [], 'output': []} + + if control_input is not None: + for i in range(len(control_input)): + key = 'input' + x = control_input[i] + if x is not None: + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + out[key].insert(0, x) + + if control_output is not None: + for i in range(len(control_output)): + if i == (len(control_output) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i + x = control_output[i] + if x is not None: + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + + out[key].append(x) + if control_prev is not None: + for x in ['input', 'middle', 'output']: + o = out[x] + for i in range(len(control_prev[x])): + prev_val = control_prev[x][i] + if i >= len(o): + o.append(prev_val) + elif prev_val is not None: + if o[i] is None: + o[i] = prev_val + else: + if o[i].shape[0] < prev_val.shape[0]: + o[i] = prev_val + o[i] + else: + o[i] += prev_val + return out + + def patch_all_basics(): ModelPatcher.__init__ = patched_model_patcher_init ModelPatcher.clone = patched_model_patcher_clone ModelPatcher.add_patched_controlnet = model_patcher_add_patched_controlnet ModelPatcher.list_controlnets = model_patcher_list_controlnets + ControlBase.control_merge = patched_control_merge return