actually got gradient checkpointing working, again, again, maybe

This commit is contained in:
Jaret Burkett
2023-09-09 11:27:42 -06:00
parent 4ed03a8d92
commit 408c50ead1
5 changed files with 102 additions and 70 deletions

View File

@@ -29,17 +29,19 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
lora_dim=4, alpha=1,
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
parent=None,
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
# call super of super
torch.nn.Module.__init__(self)
        # call super of parent (LoConModule) — comment was truncated in the diff; verify against full file
super().__init__(
org_module=org_module,
call_super_init=False,
parent=parent,
**kwargs
)
# call super of super
super(LoConModule, self).__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
@@ -163,7 +165,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
**kwargs
)
# call the parent of the parent LycorisNetwork
super(LycorisNetwork, self).__init__()
torch.nn.Module.__init__(self)
# LyCORIS unique stuff
if dropout is None:
@@ -176,8 +178,9 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.multiplier = multiplier
self.lora_dim = lora_dim
if not self.ENABLE_CONV:
if not self.ENABLE_CONV or conv_lora_dim is None:
conv_lora_dim = 0
conv_alpha = 0
self.conv_lora_dim = int(conv_lora_dim)
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
@@ -231,6 +234,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
elif child_module.__class__.__name__ in CONV_MODULES:
@@ -241,6 +245,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
elif conv_lora_dim > 0:
@@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
else:
@@ -269,6 +275,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
elif module.__class__.__name__ == 'Conv2d':
@@ -279,6 +286,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
elif conv_lora_dim > 0:
@@ -287,6 +295,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
**kwargs
)
else: