actually got gradient checkpointing working, again, again, maybe

Author: Jaret Burkett
Date: 2023-09-09 11:27:42 -06:00
Parent: 4ed03a8d92
Commit: 408c50ead1
5 changed files with 102 additions and 70 deletions


@@ -46,9 +46,14 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        parent=None,
         **kwargs
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
-        super().__init__()
+        super().__init__(
+            org_module=org_module,
+            parent=parent
+        )
         self.lora_name = lora_name
         self.scalar = torch.tensor(1.0)
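
The super().__init__(org_module=org_module, parent=parent) call above resolves, via the MRO of LoRAModule(ToolkitModuleMixin, torch.nn.Module), to ToolkitModuleMixin, which is not part of this diff. A minimal sketch of a mixin that could accept the two new keyword arguments follows; the attribute names and the list wrapping are assumptions for illustration, not the repo's actual implementation.

class ToolkitModuleMixin:
    def __init__(self, *args, org_module=None, parent=None, **kwargs):
        # Continue the MRO chain first (ending at torch.nn.Module.__init__)
        # so attribute assignment on the module is safe afterwards.
        super().__init__(*args, **kwargs)
        # Keep plain-list references so neither the wrapped layer nor the
        # parent block is registered as a child module; registering the
        # parent directly would duplicate its parameters and create a cycle
        # when walking the module tree.
        self.org_module_ref = [org_module] if org_module is not None else []
        self.parent_ref = [parent] if parent is not None else []

Passing parent= at construction gives each LoRA module a handle back to the block it patches, which is one plausible way for the network to coordinate state such as a gradient-checkpointing flag across all of its injected modules.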
@@ -256,6 +261,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             dropout=dropout,
             rank_dropout=rank_dropout,
             module_dropout=module_dropout,
+            parent=module,
         )
         loras.append(lora)
     return loras, skipped
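
The second hunk passes parent=module when the network instantiates each LoRA module. The actual gradient-checkpointing wiring lives in the mixin and network code not shown in this diff; the sketch below is only a rough, self-contained illustration of running a LoRA-style wrapper under torch.utils.checkpoint and confirming that gradients still reach the injected weights. SimpleLoRA, its constructor, and the parent handling are stand-ins, not the repo's LoRAModule.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SimpleLoRA(nn.Module):
    """Stand-in for LoRAModule: wraps a Linear layer with a rank-r adapter."""

    def __init__(self, org_module, parent=None, rank=4):
        super().__init__()
        # Mirror the pattern in the diff: keep unregistered references to the
        # wrapped layer and its parent block.
        self._org = [org_module]
        self._parent = [parent]
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # start as a no-op adapter

    def forward(self, x):
        return self._org[0](x) + self.lora_up(self.lora_down(x))


block = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
lora = SimpleLoRA(block[0], parent=block)

x = torch.randn(2, 16, requires_grad=True)
# use_reentrant=False selects the non-reentrant mode recommended by PyTorch
# and avoids the pitfall where reentrant checkpointing silently produces no
# gradients when no input requires grad.
y = checkpoint(lora, x, use_reentrant=False)
y.sum().backward()
print(lora.lora_down.weight.grad is not None)  # True
print(lora.lora_up.weight.grad is not None)    # True

One reason to capture org_module and parent at construction time is that a checkpointed forward is re-executed during backward, so the module needs stable handles rather than state discovered on the fly; whether that is exactly why this commit adds parent is not visible in the diff.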