From 015f0967e1baeb2bd6842486a9e33c2184f7696d Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 12:17:31 -0800 Subject: [PATCH] fully patch unet --- modules_forge/unet_patcher.py | 55 ++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 175138df..11895e38 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -47,19 +47,38 @@ class UnetPatcher(ModelPatcher): self.model_options[k].append(v) return + def append_transformer_option(self, k, v, ensure_uniqueness=False): + if 'transformer_options' not in self.model_options: + self.model_options['transformer_options'] = {} + + to = self.model_options['transformer_options'] + + if k not in to: + to[k] = [] + + if ensure_uniqueness and v in to[k]: + return + + to[k].append(v) + return + def add_conditioning_modifier(self, modifier, ensure_uniqueness=False): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return def add_block_modifier(self, modifier, ensure_uniqueness=False): - self.append_model_option('block_modifiers', modifier, ensure_uniqueness) + self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness) return -def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): +def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} + transformer_options["original_shape"] = list(x.shape) transformer_options["transformer_index"] = 0 transformer_patches = transformer_options.get("patches", {}) + block_modifiers = transformer_options.get("block_modifiers", []) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) @@ -79,9 +98,17 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'input') + + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + if "input_block_patch" in transformer_patches: patch = transformer_patches["input_block_patch"] for p in patch: @@ -94,10 +121,17 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() @@ -114,18 +148,31 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No output_shape = hs[-1].shape else: output_shape = None + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) - h = h.type(x.dtype) + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + transformer_options["block"] = ("last", 0) + + for block_modifier in block_modifiers: + h = block_modifier(h, 'before', transformer_options) if self.predict_codebook_ids: h = self.id_predictor(h) else: h = self.out(h) - return h + for block_modifier in block_modifiers: + h = block_modifier(h, 'after', transformer_options) + + return h.type(x.dtype) def patch_all():