Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-03-02 03:29:49 +00:00)
GroupNorm patcher (#593)
* add gn wrapper and corresponding patcher
Committed by GitHub
parent 0eea1acc6e
commit 5192e912ab
@@ -30,7 +30,7 @@ class TimestepBlock(nn.Module):
     """
 
     @abstractmethod
-    def forward(self, x, emb):
+    def forward(self, x, emb, transformer_options={}):
         """
         Apply the module to `x` given `emb` timestep embeddings.
         """
@@ -46,7 +46,7 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
         if isinstance(layer, VideoResBlock):
             x = layer(x, emb, num_video_frames, image_only_indicator)
         elif isinstance(layer, TimestepBlock):
-            x = layer(x, emb)
+            x = layer(x, emb, transformer_options)
         elif isinstance(layer, SpatialVideoTransformer):
             x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
             if "transformer_index" in transformer_options:
@@ -234,7 +234,7 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
 
-    def forward(self, x, emb):
+    def forward(self, x, emb, transformer_options={}):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
         :param x: an [N x C x ...] Tensor of features.
@@ -242,19 +242,29 @@ class ResBlock(TimestepBlock):
         :return: an [N x C x ...] Tensor of outputs.
         """
         return checkpoint(
-            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+            self._forward, (x, emb, transformer_options), self.parameters(), self.use_checkpoint
         )
 
 
-    def _forward(self, x, emb):
+    def _forward(self, x, emb, transformer_options={}):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
+            if "groupnorm_wrapper" in transformer_options:
+                in_norm, in_rest = in_rest[0], in_rest[1:]
+                h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
+                h = in_rest(h)
+            else:
+                h = in_rest(x)
             h = self.h_upd(h)
             x = self.x_upd(x)
             h = in_conv(h)
         else:
-            h = self.in_layers(x)
+            if "groupnorm_wrapper" in transformer_options:
+                in_norm = self.in_layers[0]
+                h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
+                h = self.in_layers[1:](h)
+            else:
+                h = self.in_layers(x)
 
         emb_out = None
         if not self.skip_t_emb:
@@ -263,7 +273,10 @@ class ResBlock(TimestepBlock):
                 emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            h = out_norm(h)
+            if "groupnorm_wrapper" in transformer_options:
+                h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
+            else:
+                h = out_norm(h)
             if emb_out is not None:
                 scale, shift = th.chunk(emb_out, 2, dim=1)
                 h *= (1 + scale)
@@ -274,7 +287,11 @@ class ResBlock(TimestepBlock):
                 if self.exchange_temb_dims:
                     emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
                 h = h + emb_out
-            h = self.out_layers(h)
+            if "groupnorm_wrapper" in transformer_options:
+                h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options)
+                h = self.out_layers[1:](h)
+            else:
+                h = self.out_layers(h)
         return self.skip_connection(x) + h
 
 
@@ -924,6 +941,10 @@ class UNetModel(nn.Module):
 
         if self.predict_codebook_ids:
             h = self.id_predictor(h)
+        elif "groupnorm_wrapper" in transformer_options:
+            out_norm, out_rest = self.out[0], self.out[1:]
+            h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
+            h = out_rest(h)
         else:
             h = self.out(h)
 
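Taken together, the ResBlock and UNetModel call sites above expect `groupnorm_wrapper` to be a callable that receives the GroupNorm module, the tensor to normalize, and the transformer_options dict, and returns the normalized tensor. A minimal sketch of a conforming wrapper, assuming only that call signature (the function name and the rescale key read from transformer_options are illustrative, not part of this commit):

def example_groupnorm_wrapper(norm_layer, x, transformer_options):
    # Run the original GroupNorm module unchanged.
    h = norm_layer(x)
    # Illustrative extra step: rescale the normalized output by a factor
    # stored in transformer_options (hypothetical key, not from this commit).
    scale = transformer_options.get("example_gn_scale", 1.0)
    return h * scale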
@@ -159,6 +159,10 @@ class UnetPatcher(ModelPatcher):
         self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
         return
 
+    def set_groupnorm_wrapper(self, wrapper):
+        self.set_transformer_option('groupnorm_wrapper', wrapper)
+        return
+
     def set_controlnet_model_function_wrapper(self, wrapper):
         self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
         return
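With set_groupnorm_wrapper available on UnetPatcher, an extension can register a wrapper like the sketch above on its own copy of the UNet patcher. How that patcher is obtained varies by integration; the snippet below assumes the common Forge script pattern of cloning p.sd_model.forge_objects.unet inside a processing callback, which is an assumption about the caller rather than something this commit defines:

# Assumed context: `p` is the processing object inside a script callback and
# forge_objects.unet is the UnetPatcher wrapping the diffusion UNet.
unet = p.sd_model.forge_objects.unet.clone()            # work on a copy
unet.set_groupnorm_wrapper(example_groupnorm_wrapper)   # new setter from this commit
p.sd_model.forge_objects.unet = unet                    # swap the patched copy back in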