From d7991ca8462f2b51ab80aee9b857cb3e1918d5f1 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 11:59:13 -0800 Subject: [PATCH] i --- modules_forge/initialization.py | 3 ++ modules_forge/unet_patcher.py | 83 +++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 6c3839e7..415f33af 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -33,6 +33,9 @@ def initialize_forge(): import modules_forge.patch_basic modules_forge.patch_basic.patch_all_basics() + import modules_forge.unet_patcher + modules_forge.unet_patcher.patch_all() + if model_management.directml_enabled: model_management.lowvram_available = True model_management.OOM_EXCEPTION = Exception diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index afdabfd8..175138df 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -1,4 +1,7 @@ import copy +import torch + +from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed, apply_control from ldm_patched.modules.model_patcher import ModelPatcher @@ -47,3 +50,83 @@ class UnetPatcher(ModelPatcher): def add_conditioning_modifier(self, modifier, ensure_uniqueness=False): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return + + def add_block_modifier(self, modifier, ensure_uniqueness=False): + self.append_model_option('block_modifiers', modifier, ensure_uniqueness) + return + + +def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + transformer_options["original_shape"] = list(x.shape) + transformer_options["transformer_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) + + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, + num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'input') + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + + hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) + + transformer_options["block"] = ("middle", 0) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, + num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'middle') + + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) + hsp = hs.pop() + hsp = apply_control(hsp, control, 'output') + + if "output_block_patch" in transformer_patches: + patch = transformer_patches["output_block_patch"] + for p in patch: + h, hsp = p(h, hsp, transformer_options) + + h = torch.cat([h, hsp], dim=1) + del hsp + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, + time_context=time_context, num_video_frames=num_video_frames, + image_only_indicator=image_only_indicator) + + h = h.type(x.dtype) + + if self.predict_codebook_ids: + h = self.id_predictor(h) + else: + h = self.out(h) + + return h + + +def patch_all(): + UNetModel.forward = forge_unet_forward