forge 2.0.0

see also discussions
This commit is contained in:
lllyasviel
2024-08-10 19:24:19 -07:00
committed by GitHub
parent 4014013d05
commit cfa5242a75
28 changed files with 785 additions and 1249 deletions

View File

@@ -433,9 +433,9 @@ class ResBlock(TimestepBlock):
def _forward(self, x, emb, transformer_options={}):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
in_norm, in_rest = in_rest[0], in_rest[1:]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
h = in_rest(h)
else:
h = in_rest(x)
@@ -443,9 +443,9 @@ class ResBlock(TimestepBlock):
x = self.x_upd(x)
h = in_conv(h)
else:
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
in_norm = self.in_layers[0]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
h = self.in_layers[1:](h)
else:
h = self.in_layers(x)
@@ -456,8 +456,8 @@ class ResBlock(TimestepBlock):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
if "group_norm_wrapper" in transformer_options:
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
else:
h = out_norm(h)
if emb_out is not None:
@@ -470,8 +470,8 @@ class ResBlock(TimestepBlock):
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options)
if "group_norm_wrapper" in transformer_options:
h = transformer_options["group_norm_wrapper"](self.out_layers[0], h, transformer_options)
h = self.out_layers[1:](h)
else:
h = self.out_layers(h)
@@ -752,9 +752,9 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
out_norm, out_rest = self.out[0], self.out[1:]
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
h = out_rest(h)
else:
h = self.out(h)