This commit is contained in:
lllyasviel
2024-01-31 13:16:41 -08:00
parent 86bd258a8e
commit 67c38f9294
11 changed files with 167 additions and 68 deletions

View File

@@ -825,6 +825,7 @@ class UNetModel(nn.Module):
transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
block_modifiers = transformer_options.get("block_modifiers", [])
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
@@ -844,8 +845,16 @@ class UNetModel(nn.Module):
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
@@ -858,9 +867,15 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
@@ -878,9 +893,26 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if self.predict_codebook_ids:
return self.id_predictor(h)
h = self.id_predictor(h)
else:
return self.out(h)
h = self.out(h)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
return h.type(x.dtype)