mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-10 18:09:58 +00:00
fully patch unet
This commit is contained in:
@@ -47,19 +47,38 @@ class UnetPatcher(ModelPatcher):
|
||||
self.model_options[k].append(v)
|
||||
return
|
||||
|
||||
def append_transformer_option(self, k, v, ensure_uniqueness=False):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
to = self.model_options['transformer_options']
|
||||
|
||||
if k not in to:
|
||||
to[k] = []
|
||||
|
||||
if ensure_uniqueness and v in to[k]:
|
||||
return
|
||||
|
||||
to[k].append(v)
|
||||
return
|
||||
|
||||
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('block_modifiers', modifier, ensure_uniqueness)
|
||||
self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
|
||||
def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["transformer_index"] = 0
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
block_modifiers = transformer_options.get("block_modifiers", [])
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
||||
@@ -79,9 +98,17 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No
|
||||
h = x
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'before', transformer_options)
|
||||
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context,
|
||||
num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'input')
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'after', transformer_options)
|
||||
|
||||
if "input_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch"]
|
||||
for p in patch:
|
||||
@@ -94,10 +121,17 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'before', transformer_options)
|
||||
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context,
|
||||
num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'after', transformer_options)
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
@@ -114,18 +148,31 @@ def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=No
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'before', transformer_options)
|
||||
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape,
|
||||
time_context=time_context, num_video_frames=num_video_frames,
|
||||
image_only_indicator=image_only_indicator)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'after', transformer_options)
|
||||
|
||||
transformer_options["block"] = ("last", 0)
|
||||
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'before', transformer_options)
|
||||
|
||||
if self.predict_codebook_ids:
|
||||
h = self.id_predictor(h)
|
||||
else:
|
||||
h = self.out(h)
|
||||
|
||||
return h
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'after', transformer_options)
|
||||
|
||||
return h.type(x.dtype)
|
||||
|
||||
|
||||
def patch_all():
|
||||
|
||||
Reference in New Issue
Block a user