From a3dffecb3fbad1175ac3cd739629d340c454a1e9 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 28 Jan 2024 06:21:53 -0800 Subject: [PATCH] i --- modules_forge/forge_loader.py | 3 ++- modules_forge/patch_basic.py | 36 ----------------------------------- modules_forge/unet_patcher.py | 35 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 37 deletions(-) create mode 100644 modules_forge/unet_patcher.py diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 7c29da5d..b06a58a7 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -16,6 +16,7 @@ from modules import sd_hijack from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config from modules_forge import forge_clip +from modules_forge.unet_patcher import UnetPatcher import open_clip from transformers import CLIPTextModel, CLIPTokenizer @@ -116,7 +117,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c print("left over keys:", left_over) if output_model: - model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = UnetPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index 46045907..d4bff6ed 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -2,42 +2,10 @@ import torch import ldm_patched.modules.samplers from ldm_patched.modules.controlnet import ControlBase -from ldm_patched.modules.model_patcher import ModelPatcher from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat from ldm_patched.modules import model_management -og_model_patcher_init = ModelPatcher.__init__ -og_model_patcher_clone = ModelPatcher.clone - - -def patched_model_patcher_init(self, *args, **kwargs): - h = og_model_patcher_init(self, *args, **kwargs) - self.controlnet_linked_list = None - return h - - -def patched_model_patcher_clone(self): - cloned = og_model_patcher_clone(self) - cloned.controlnet_linked_list = self.controlnet_linked_list - return cloned - - -def model_patcher_add_patched_controlnet(self, cnet): - cnet.set_previous_controlnet(self.controlnet_linked_list) - self.controlnet_linked_list = cnet - return - - -def model_patcher_list_controlnets(self): - results = [] - pointer = self.controlnet_linked_list - while pointer is not None: - results.append(pointer) - pointer = pointer.previous_controlnet - return results - - def patched_control_merge(self, control_input, control_output, control_prev, output_dtype): out = {'input': [], 'middle': [], 'output': []} @@ -208,10 +176,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op def patch_all_basics(): - ModelPatcher.__init__ = patched_model_patcher_init - ModelPatcher.clone = patched_model_patcher_clone - ModelPatcher.add_patched_controlnet = model_patcher_add_patched_controlnet - ModelPatcher.list_controlnets = model_patcher_list_controlnets ControlBase.control_merge = patched_control_merge ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch return diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py new file mode 100644 index 00000000..53df477e --- /dev/null +++ b/modules_forge/unet_patcher.py @@ -0,0 +1,35 @@ +import copy +from ldm_patched.modules.model_patcher import ModelPatcher + + +class UnetPatcher(ModelPatcher): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.controlnet_linked_list = None + + def clone(self): + n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, + weight_inplace_update=self.weight_inplace_update) + + n.patches = {} + for k in self.patches: + n.patches[k] = self.patches[k][:] + + n.object_patches = self.object_patches.copy() + n.model_options = copy.deepcopy(self.model_options) + n.model_keys = self.model_keys + n.controlnet_linked_list = self.controlnet_linked_list + return n + + def add_patched_controlnet(self, cnet): + cnet.set_previous_controlnet(self.controlnet_linked_list) + self.controlnet_linked_list = cnet + return + + def list_controlnets(self): + results = [] + pointer = self.controlnet_linked_list + while pointer is not None: + results.append(pointer) + pointer = pointer.previous_controlnet + return results