From 5192e912ab120947627aa3e4055ff3445a025b83 Mon Sep 17 00:00:00 2001 From: continue revolution Date: Fri, 2 Aug 2024 03:53:27 +0800 Subject: [PATCH] GroupNorm patcher (#593) * add gn wrapper and corresponding patcher * add gn wrapper and corresponding patcher --- .../modules/diffusionmodules/openaimodel.py | 39 ++++++++++++++----- modules_forge/unet_patcher.py | 4 ++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py index 9904e744..0f34e69c 100644 --- a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py @@ -30,7 +30,7 @@ class TimestepBlock(nn.Module): """ @abstractmethod - def forward(self, x, emb): + def forward(self, x, emb, transformer_options={}): """ Apply the module to `x` given `emb` timestep embeddings. """ @@ -46,7 +46,7 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out if isinstance(layer, VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, TimestepBlock): - x = layer(x, emb) + x = layer(x, emb, transformer_options) elif isinstance(layer, SpatialVideoTransformer): x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) if "transformer_index" in transformer_options: @@ -234,7 +234,7 @@ class ResBlock(TimestepBlock): else: self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) - def forward(self, x, emb): + def forward(self, x, emb, transformer_options={}): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. @@ -242,19 +242,29 @@ class ResBlock(TimestepBlock): :return: an [N x C x ...] Tensor of outputs. """ return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint + self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint ) - def _forward(self, x, emb): + def _forward(self, x, emb, transformer_options={}): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) + if "groupnorm_wrapper" in transformer_options: + in_norm, in_rest = in_rest[0], in_rest[1:] + h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options) + h = in_rest(h) + else: + h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: - h = self.in_layers(x) + if "groupnorm_wrapper" in transformer_options: + in_norm = self.in_layers[0] + h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options) + h = self.in_layers[1:](h) + else: + h = self.in_layers(x) emb_out = None if not self.skip_t_emb: @@ -263,7 +273,10 @@ class ResBlock(TimestepBlock): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - h = out_norm(h) + if "groupnorm_wrapper" in transformer_options: + h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options) + else: + h = out_norm(h) if emb_out is not None: scale, shift = th.chunk(emb_out, 2, dim=1) h *= (1 + scale) @@ -274,7 +287,11 @@ class ResBlock(TimestepBlock): if self.exchange_temb_dims: emb_out = rearrange(emb_out, "b t c ... -> b c t ...") h = h + emb_out - h = self.out_layers(h) + if "groupnorm_wrapper" in transformer_options: + h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options) + h = self.out_layers[1:](h) + else: + h = self.out_layers(h) return self.skip_connection(x) + h @@ -924,6 +941,10 @@ class UNetModel(nn.Module): if self.predict_codebook_ids: h = self.id_predictor(h) + elif "groupnorm_wrapper" in transformer_options: + out_norm, out_rest = self.out[0], self.out[1:] + h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options) + h = out_rest(h) else: h = self.out(h) diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 01164c09..4b1d14f8 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -159,6 +159,10 @@ class UnetPatcher(ModelPatcher): self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness) return + def set_groupnorm_wrapper(self, wrapper): + self.set_transformer_option('groupnorm_wrapper', wrapper) + return + def set_controlnet_model_function_wrapper(self, wrapper): self.set_transformer_option('controlnet_model_function_wrapper', wrapper) return