GroupNorm patcher (#593)

* Add GroupNorm (gn) wrapper and corresponding patcher

* Add GroupNorm (gn) wrapper and corresponding patcher
This commit is contained in:
continue revolution
2024-08-02 03:53:27 +08:00
committed by GitHub
parent 0eea1acc6e
commit 5192e912ab
2 changed files with 34 additions and 9 deletions

View File

@@ -30,7 +30,7 @@ class TimestepBlock(nn.Module):
"""
@abstractmethod
def forward(self, x, emb):
def forward(self, x, emb, transformer_options={}):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
@@ -46,7 +46,7 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
x = layer(x, emb, transformer_options)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
if "transformer_index" in transformer_options:
@@ -234,7 +234,7 @@ class ResBlock(TimestepBlock):
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
def forward(self, x, emb):
def forward(self, x, emb, transformer_options={}):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
@@ -242,19 +242,29 @@ class ResBlock(TimestepBlock):
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
def _forward(self, x, emb, transformer_options={}):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
if "groupnorm_wrapper" in transformer_options:
in_norm, in_rest = in_rest[0], in_rest[1:]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = in_rest(h)
else:
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if "groupnorm_wrapper" in transformer_options:
in_norm = self.in_layers[0]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = self.in_layers[1:](h)
else:
h = self.in_layers(x)
emb_out = None
if not self.skip_t_emb:
@@ -263,7 +273,10 @@ class ResBlock(TimestepBlock):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
h = out_norm(h)
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
else:
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1)
h *= (1 + scale)
@@ -274,7 +287,11 @@ class ResBlock(TimestepBlock):
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options)
h = self.out_layers[1:](h)
else:
h = self.out_layers(h)
return self.skip_connection(x) + h
@@ -924,6 +941,10 @@ class UNetModel(nn.Module):
if self.predict_codebook_ids:
h = self.id_predictor(h)
elif "groupnorm_wrapper" in transformer_options:
out_norm, out_rest = self.out[0], self.out[1:]
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
h = out_rest(h)
else:
h = self.out(h)

View File

@@ -159,6 +159,10 @@ class UnetPatcher(ModelPatcher):
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
return
# Register `wrapper` under the 'groupnorm_wrapper' transformer option.
#
# The UNet code changed in this same commit checks for this key in
# `transformer_options` and, when present, calls
# `wrapper(norm_layer, x, transformer_options)` instead of `norm_layer(x)`,
# using the return value as the tensor fed to the remaining layers.
# `norm_layer` is the first element of a ResBlock's `in_layers` /
# `out_layers`, or of the UNet's final `out` sequence.
#
# NOTE(review): `set_transformer_option` is defined outside this view;
# presumably it replaces any previously registered wrapper rather than
# chaining them — confirm before relying on multiple wrappers.
def set_groupnorm_wrapper(self, wrapper):
self.set_transformer_option('groupnorm_wrapper', wrapper)
return
# Register `wrapper` under the 'controlnet_model_function_wrapper'
# transformer option.
#
# NOTE(review): the code that reads this option is not visible here, so the
# call signature expected of `wrapper` should be confirmed against the
# ControlNet code that consumes it.
def set_controlnet_model_function_wrapper(self, wrapper):
self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
return