This commit is contained in:
lllyasviel
2024-01-30 19:15:47 -08:00
parent df673fa6dc
commit aa5d960c8e
2 changed files with 19 additions and 20 deletions

View File

@@ -3,6 +3,11 @@ import torch
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor
from ldm_patched.modules.samplers import sampling_function
import ldm_patched.ldm.modules.attention as attention
def sdp(q, k, v, transformer_options):
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
class PreprocessorReference(Preprocessor):
@@ -59,6 +64,9 @@ class PreprocessorReference(Preprocessor):
if not self.use_adain:
return h
if flag != 'after':
return
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return h
@@ -72,38 +80,22 @@ class PreprocessorReference(Preprocessor):
def attn1_proc(q, k, v, transformer_options):
if not self.use_attn:
return q, k, v
return sdp(q, k, v, transformer_options)
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return q, k, v
return sdp(q, k, v, transformer_options)
if self.is_recording_style:
a = 0
else:
b = 0
return q, k, v
def attn1_output_proc(h, transformer_options):
if not self.use_attn:
return h
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return h
if self.is_recording_style:
a = 0
else:
b = 0
return h
return sdp(q, k, v, transformer_options)
unet.add_block_modifier(block_proc)
unet.add_conditioning_modifier(conditioning_modifier)
unet.set_model_attn1_patch(attn1_proc)
unet.set_model_attn1_output_patch(attn1_output_proc)
unet.set_model_replace_all(attn1_proc, 'attn1')
process.sd_model.forge_objects.unet = unet