mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Bug fixes
This commit is contained in:
@@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
|
||||
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
|
||||
|
||||
transformer.pos_embed.orig_forward = transformer.pos_embed.forward
|
||||
transformer.proj_out.orig_forward = transformer.proj_out.forward
|
||||
|
||||
transformer.pos_embed.forward = self.transformer_pos_embed.forward
|
||||
transformer.proj_out.forward = self.transformer_proj_out.forward
|
||||
transformer.pos_embed = self.transformer_pos_embed
|
||||
transformer.proj_out = self.transformer_proj_out
|
||||
|
||||
else:
|
||||
unet: UNet2DConditionModel = unet
|
||||
@@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
# clone these and replace their forwards with ours
|
||||
self.unet_conv_in = copy.deepcopy(unet_conv_in)
|
||||
self.unet_conv_out = copy.deepcopy(unet_conv_out)
|
||||
unet.conv_in.orig_forward = unet_conv_in.forward
|
||||
unet_conv_out.orig_forward = unet_conv_out.forward
|
||||
unet.conv_in.forward = self.unet_conv_in.forward
|
||||
unet.conv_out.forward = self.unet_conv_out.forward
|
||||
unet.conv_in = self.unet_conv_in
|
||||
unet.conv_out = self.unet_conv_out
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# call Lora prepare_optimizer_params
|
||||
|
||||
Reference in New Issue
Block a user