This commit is contained in:
lllyasviel
2024-01-30 11:59:13 -08:00
parent 0f532823b2
commit d7991ca846
2 changed files with 86 additions and 0 deletions

View File

@@ -33,6 +33,9 @@ def initialize_forge():
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
import modules_forge.unet_patcher
modules_forge.unet_patcher.patch_all()
if model_management.directml_enabled:
model_management.lowvram_available = True
model_management.OOM_EXCEPTION = Exception

View File

@@ -1,4 +1,7 @@
import copy
import torch
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed, apply_control
from ldm_patched.modules.model_patcher import ModelPatcher
@@ -47,3 +50,83 @@ class UnetPatcher(ModelPatcher):
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
    """Register ``modifier`` under the ``'conditioning_modifiers'`` model option.

    The call is delegated to ``self.append_model_option``; ``ensure_uniqueness``
    is forwarded to it unchanged. Returns ``None``.
    """
    option_key = 'conditioning_modifiers'
    self.append_model_option(option_key, modifier, ensure_uniqueness)
def add_block_modifier(self, modifier, ensure_uniqueness=False):
    """Register ``modifier`` under the ``'block_modifiers'`` model option.

    The call is delegated to ``self.append_model_option``; ``ensure_uniqueness``
    is forwarded to it unchanged. Returns ``None``.
    """
    option_key = 'block_modifiers'
    self.append_model_option(option_key, modifier, ensure_uniqueness)
def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
    """Forward pass installed onto ``UNetModel`` by ``patch_all``.

    Runs the input blocks, the middle block and the output blocks in order,
    applying ``control`` via ``apply_control`` and any callables registered
    under ``transformer_options["patches"]`` at the corresponding hook points.

    Args:
        self: The UNet instance (this function is assigned as
            ``UNetModel.forward``).
        x: Input tensor; its shape is recorded in
            ``transformer_options["original_shape"]`` and its dtype is used for
            the timestep embedding and for the tensor handed to the output head.
        timesteps: Passed to ``timestep_embedding``.
        context: Forwarded to every ``forward_timestep_embed`` call.
        y: Class labels. Must be given if and only if ``self.num_classes`` is
            not ``None``; its first dimension must match ``x``'s.
        control: Forwarded to ``apply_control`` for the 'input', 'middle' and
            'output' stages.
        transformer_options: Optional dict. It is mutated in place (keys
            ``"original_shape"``, ``"transformer_index"`` and ``"block"`` are
            written) and its ``"patches"`` entry, if present, is read. When
            omitted, a fresh dict is created for this call so no state leaks
            between calls.
        **kwargs: ``num_video_frames``, ``image_only_indicator`` and
            ``time_context`` are read; the first two fall back to
            ``self.default_num_video_frames`` and
            ``self.default_image_only_indicator``.

    Returns:
        ``self.id_predictor(h)`` when ``self.predict_codebook_ids`` is truthy,
        otherwise ``self.out(h)``.

    Raises:
        AssertionError: If ``y`` is supplied for a model without classes (or
            missing for a class-conditional one), or if the batch dimensions of
            ``y`` and ``x`` differ.
    """
    # A literal ``{}`` default would be one dict shared (and mutated) by every
    # call that omits the argument; create a fresh one per call instead.
    if transformer_options is None:
        transformer_options = {}

    transformer_options["original_shape"] = list(x.shape)
    transformer_options["transformer_index"] = 0
    transformer_patches = transformer_options.get("patches", {})

    num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
    image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
    time_context = kwargs.get("time_context", None)

    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    hs = []  # outputs of the input blocks, consumed in reverse by the output blocks
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
    emb = self.time_embed(t_emb)

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y)

    h = x
    for block_index, module in enumerate(self.input_blocks):
        transformer_options["block"] = ("input", block_index)
        h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context,
                                   num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        h = apply_control(h, control, 'input')

        # These patches run before the activation is stored for the skip connection.
        for p in transformer_patches.get("input_block_patch", ()):
            h = p(h, transformer_options)

        hs.append(h)

        # These patches only affect what flows on to the next block, not the stored skip.
        for p in transformer_patches.get("input_block_patch_after_skip", ()):
            h = p(h, transformer_options)

    transformer_options["block"] = ("middle", 0)
    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context,
                               num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
    h = apply_control(h, control, 'middle')

    for block_index, module in enumerate(self.output_blocks):
        transformer_options["block"] = ("output", block_index)
        hsp = hs.pop()
        hsp = apply_control(hsp, control, 'output')

        for p in transformer_patches.get("output_block_patch", ()):
            h, hsp = p(h, hsp, transformer_options)

        h = torch.cat([h, hsp], dim=1)
        del hsp  # drop the reference to the stored activation once it is concatenated

        # Shape of the next stored activation, handed to the block as ``output_shape``;
        # ``None`` once the stack is exhausted.
        output_shape = hs[-1].shape if hs else None
        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape,
                                   time_context=time_context, num_video_frames=num_video_frames,
                                   image_only_indicator=image_only_indicator)

    h = h.type(x.dtype)
    if self.predict_codebook_ids:
        h = self.id_predictor(h)
    else:
        h = self.out(h)
    return h
def patch_all():
    """Replace ``UNetModel.forward`` with ``forge_unet_forward``.

    The assignment is made on the class itself, not on an instance.
    """
    UNetModel.forward = forge_unet_forward