mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
FINALLY fixed gradient checkpointing issue. Big batches baby.
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union, List, Type, TYPE_CHECKING
|
||||
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
|
||||
|
||||
import torch
|
||||
from diffusers.utils import is_torch_version
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
@@ -26,6 +28,7 @@ class ToolkitModuleMixin:
|
||||
):
|
||||
if call_super_init:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.org_module: torch.nn.Module = kwargs.get('org_module', None)
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
@@ -65,6 +68,8 @@ class ToolkitModuleMixin:
|
||||
return multiplier_tensor.detach()
|
||||
|
||||
else:
|
||||
if isinstance(self.multiplier, torch.Tensor):
|
||||
return self.multiplier.detach()
|
||||
return self.multiplier
|
||||
|
||||
def _call_forward(self: Module, x):
|
||||
@@ -111,6 +116,7 @@ class ToolkitModuleMixin:
|
||||
return lx * scale
|
||||
|
||||
def forward(self: Module, x):
|
||||
x = x.detach()
|
||||
org_forwarded = self.org_forward(x)
|
||||
lora_output = self._call_forward(x)
|
||||
multiplier = self.get_multiplier(lora_output)
|
||||
@@ -236,7 +242,6 @@ class ToolkitNetworkMixin:
|
||||
):
|
||||
keymap = self.get_keymap()
|
||||
|
||||
|
||||
save_keymap = {}
|
||||
if keymap is not None:
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
|
||||
Reference in New Issue
Block a user