From aa5d960c8e73eb8f85a8f4c74105cbe190beb166 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Tue, 30 Jan 2024 19:15:47 -0800
Subject: [PATCH] i

---
 .../scripts/forge_reference.py               | 32 +++++++------------
 modules_forge/unet_patcher.py                |  7 ++++
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py
index 644345a1..ad6b48ab 100644
--- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py
+++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py
@@ -3,6 +3,11 @@ import torch
 from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
 from modules_forge.shared import add_supported_preprocessor
 from ldm_patched.modules.samplers import sampling_function
+import ldm_patched.ldm.modules.attention as attention
+
+
+def sdp(q, k, v, transformer_options):
+    return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
 
 
 class PreprocessorReference(Preprocessor):
@@ -59,6 +64,9 @@ class PreprocessorReference(Preprocessor):
             if not self.use_adain:
                 return h
 
+            if flag != 'after':
+                return
+
             sigma = transformer_options["sigmas"][0].item()
             if not (sigma_min <= sigma <= sigma_max):
                 return h
@@ -72,38 +80,22 @@ class PreprocessorReference(Preprocessor):
 
         def attn1_proc(q, k, v, transformer_options):
             if not self.use_attn:
-                return q, k, v
+                return sdp(q, k, v, transformer_options)
 
             sigma = transformer_options["sigmas"][0].item()
             if not (sigma_min <= sigma <= sigma_max):
-                return q, k, v
+                return sdp(q, k, v, transformer_options)
 
             if self.is_recording_style:
                 a = 0
             else:
                 b = 0
 
-            return q, k, v
-
-        def attn1_output_proc(h, transformer_options):
-            if not self.use_attn:
-                return h
-
-            sigma = transformer_options["sigmas"][0].item()
-            if not (sigma_min <= sigma <= sigma_max):
-                return h
-
-            if self.is_recording_style:
-                a = 0
-            else:
-                b = 0
-
-            return h
+            return sdp(q, k, v, transformer_options)
 
         unet.add_block_modifier(block_proc)
         unet.add_conditioning_modifier(conditioning_modifier)
-        unet.set_model_attn1_patch(attn1_proc)
-        unet.set_model_attn1_output_patch(attn1_output_proc)
+        unet.set_model_replace_all(attn1_proc, 'attn1')
 
         process.sd_model.forge_objects.unet = unet
 
diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index 11895e38..dcc0c382 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -70,6 +70,13 @@ class UnetPatcher(ModelPatcher):
         self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
         return
 
+    def set_model_replace_all(self, patch, target="attn1"):
+        for block_name in ['input', 'middle', 'output']:
+            for number in range(64):
+                for transformer_index in range(64):
+                    self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
+        return
+
     def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
         if transformer_options is None:
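
In short, attn1_proc is now installed as an 'attn1' replace patch for every transformer block instead of an attn1 input/output patch pair, so it must return the finished attention output rather than (q, k, v); the new sdp() helper supplies the default attention result whenever the preprocessor is inactive. Below is a minimal, self-contained Python sketch of that dispatch pattern. It is an illustration only: register_replace_all, attn1_forward, the dict-based patch table, and the use of torch.nn.functional.scaled_dot_product_attention in place of attention.optimized_attention are hypothetical stand-ins for the UnetPatcher / ldm_patched internals touched by this diff.

# Sketch only: assumed names, not Forge's actual modules or call signatures.
import torch
import torch.nn.functional as F


def default_sdp(q, k, v, extra_options):
    # Plain multi-head scaled-dot-product attention on (batch, tokens, channels) tensors,
    # standing in for what the patch's sdp() helper returns via attention.optimized_attention.
    heads = extra_options["n_heads"]
    b, n, c = q.shape

    def split(t):
        return t.view(b, n, heads, c // heads).transpose(1, 2)

    out = F.scaled_dot_product_attention(split(q), split(k), split(v))
    return out.transpose(1, 2).reshape(b, n, c)


def register_replace_all(patches, patch, target="attn1"):
    # Analogue of UnetPatcher.set_model_replace_all: blanket-register one patch under every
    # (target, block_name, block_number, transformer_index) key the UNet could use
    # (64 is just a generous upper bound, as in the diff).
    for block_name in ("input", "middle", "output"):
        for number in range(64):
            for transformer_index in range(64):
                patches[(target, block_name, number, transformer_index)] = patch


def attn1_forward(q, k, v, extra_options, patches):
    # The attention layer consults the replace table: a registered patch takes over the
    # whole attention computation; otherwise it falls back to default_sdp.
    key = ("attn1", *extra_options["block"], extra_options["transformer_index"])
    fn = patches.get(key, default_sdp)
    return fn(q, k, v, extra_options)


if __name__ == "__main__":
    patches = {}
    # Registering default_sdp itself behaves like attn1_proc when use_attn is off.
    register_replace_all(patches, default_sdp)
    opts = {"n_heads": 8, "block": ("input", 1), "transformer_index": 0}
    q = k = v = torch.randn(2, 77, 320)
    print(attn1_forward(q, k, v, opts, patches).shape)  # torch.Size([2, 77, 320])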