Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-04 06:59:59 +00:00)
@@ -3,6 +3,11 @@ import torch
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor
from ldm_patched.modules.samplers import sampling_function
import ldm_patched.ldm.modules.attention as attention


def sdp(q, k, v, transformer_options):
    return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)


class PreprocessorReference(Preprocessor):
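For orientation, the new sdp helper just forwards to whichever attention kernel ldm_patched has selected for the current backend. A minimal sketch of how it could be exercised, assuming optimized_attention takes tensors shaped (batch, tokens, heads * head_dim) as in ComfyUI-style code; the batch size, token count, and n_heads value below are illustrative only:

import torch

heads = 8
q = torch.randn(2, 4096, heads * 64)   # (batch, tokens, heads * head_dim)
k = torch.randn(2, 4096, heads * 64)
v = torch.randn(2, 4096, heads * 64)

out = sdp(q, k, v, {"n_heads": heads})  # same shape as q under these assumptions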
@@ -59,6 +64,9 @@ class PreprocessorReference(Preprocessor):
            if not self.use_adain:
                return h

            if flag != 'after':
                return

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return h
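The sigma checks above gate the reference effect to a slice of the sampling schedule: transformer_options["sigmas"][0] is the noise level of the step currently being denoised, and the patch only acts while that value lies inside [sigma_min, sigma_max]. A standalone restatement of the gate; the helper name is hypothetical, and sigma_min / sigma_max are presumably derived from the preprocessor's start/end percent elsewhere in the file, outside this hunk:

def inside_effect_window(transformer_options, sigma_min, sigma_max):
    # sigmas[0] is the sigma of the sampling step being denoised right now
    sigma = transformer_options["sigmas"][0].item()
    return sigma_min <= sigma <= sigma_max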
@@ -72,38 +80,22 @@ class PreprocessorReference(Preprocessor):

        def attn1_proc(q, k, v, transformer_options):
            if not self.use_attn:
                return q, k, v
                return sdp(q, k, v, transformer_options)

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return q, k, v
                return sdp(q, k, v, transformer_options)

            if self.is_recording_style:
                a = 0
            else:
                b = 0

            return q, k, v

        def attn1_output_proc(h, transformer_options):
            if not self.use_attn:
                return h

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return h

            if self.is_recording_style:
                a = 0
            else:
                b = 0

            return h
            return sdp(q, k, v, transformer_options)

        unet.add_block_modifier(block_proc)
        unet.add_conditioning_modifier(conditioning_modifier)
        unet.set_model_attn1_patch(attn1_proc)
        unet.set_model_attn1_output_patch(attn1_output_proc)
        unet.set_model_replace_all(attn1_proc, 'attn1')

        process.sd_model.forge_objects.unet = unet
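In the hunk above attn1_proc appears both returning (q, k, v) and returning sdp(...), which lines up with the switch to set_model_replace_all(attn1_proc, 'attn1'): in ComfyUI-style patchers, an attn1 patch added with set_model_attn1_patch receives q, k, v and hands back modified q, k, v, whereas an attn1 replace must compute and return the attention output itself. The sketch below only illustrates that replace-style shape; treat the signature and the use_attn flag as assumptions, not Forge's documented API:

def reference_attn1(q, k, v, transformer_options, use_attn=True):
    # A replace-style patch must return the attention *output*, so even the
    # "do nothing" branch has to run a plain attention kernel rather than
    # passing q, k, v through unchanged.
    if not use_attn:
        return sdp(q, k, v, transformer_options)
    # ...reference keys/values would be mixed into k and v here...
    return sdp(q, k, v, transformer_options)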
@@ -70,6 +70,13 @@ class UnetPatcher(ModelPatcher):
        self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
        return

    def set_model_replace_all(self, patch, target="attn1"):
        for block_name in ['input', 'middle', 'output']:
            for number in range(64):
                for transformer_index in range(64):
                    self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
        return


    def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
        if transformer_options is None:
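The new UnetPatcher.set_model_replace_all brute-forces every (block_name, number, transformer_index) key instead of enumerating the layers that actually exist, so every attn1 module is covered no matter which UNet variant is loaded. A rough sense of the cost and of how a caller might use it; the .clone() call is assumed from the usual Forge preprocessor pattern and is not part of this commit:

# 3 block kinds x 64 numbers x 64 transformer indices = 12,288 replace registrations
total_keys = len(['input', 'middle', 'output']) * 64 * 64

unet = process.sd_model.forge_objects.unet.clone()   # assumed: patch a copy, not the live model
unet.set_model_replace_all(attn1_proc, 'attn1')      # attn1_proc as defined in the earlier hunk
process.sd_model.forge_objects.unet = unet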