Bug fixes

This commit is contained in:
Jaret Burkett
2024-07-03 10:56:34 -06:00
parent bb57623a35
commit acb06d6ff3
6 changed files with 133 additions and 10 deletions

View File

@@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.pos_embed.orig_forward = transformer.pos_embed.forward
transformer.proj_out.orig_forward = transformer.proj_out.forward
transformer.pos_embed.forward = self.transformer_pos_embed.forward
transformer.proj_out.forward = self.transformer_proj_out.forward
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
@@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
# clone these and replace their forwards with ours
self.unet_conv_in = copy.deepcopy(unet_conv_in)
self.unet_conv_out = copy.deepcopy(unet_conv_out)
unet.conv_in.orig_forward = unet_conv_in.forward
unet_conv_out.orig_forward = unet_conv_out.forward
unet.conv_in.forward = self.unet_conv_in.forward
unet.conv_out.forward = self.unet_conv_out.forward
unet.conv_in = self.unet_conv_in
unet.conv_out = self.unet_conv_out
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# call Lora prepare_optimizer_params