mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
actually got gradient checkpointing working, again, again, maybe
This commit is contained in:
@@ -29,17 +29,19 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
|
||||
lora_dim=4, alpha=1,
|
||||
dropout=0., rank_dropout=0., module_dropout=0.,
|
||||
use_cp=False,
|
||||
parent=None,
|
||||
**kwargs,
|
||||
):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
# call super of super
|
||||
torch.nn.Module.__init__(self)
|
||||
# call super of
|
||||
super().__init__(
|
||||
org_module=org_module,
|
||||
call_super_init=False,
|
||||
parent=parent,
|
||||
**kwargs
|
||||
)
|
||||
# call super of super
|
||||
super(LoConModule, self).__init__()
|
||||
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.cp = False
|
||||
@@ -163,7 +165,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
**kwargs
|
||||
)
|
||||
# call the parent of the parent LycorisNetwork
|
||||
super(LycorisNetwork, self).__init__()
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
# LyCORIS unique stuff
|
||||
if dropout is None:
|
||||
@@ -176,8 +178,9 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if not self.ENABLE_CONV:
|
||||
if not self.ENABLE_CONV or conv_lora_dim is None:
|
||||
conv_lora_dim = 0
|
||||
conv_alpha = 0
|
||||
|
||||
self.conv_lora_dim = int(conv_lora_dim)
|
||||
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
|
||||
@@ -231,6 +234,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
@@ -241,6 +245,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
@@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@@ -269,6 +275,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
elif module.__class__.__name__ == 'Conv2d':
|
||||
@@ -279,6 +286,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
@@ -287,6 +295,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user