FINALLY fixed gradient checkpointing issue. Big batches baby.

Jaret Burkett
2023-09-08 15:21:46 -06:00
parent cb91b0d6da
commit b01ab5d375
3 changed files with 16 additions and 14 deletions


@@ -1,10 +1,12 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Optional, Union, List, Type, TYPE_CHECKING
+from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
 import torch
+from diffusers.utils import is_torch_version
 from torch import nn
+from torch.utils.checkpoint import checkpoint
 from toolkit.metadata import add_model_hash_to_meta
 from toolkit.paths import KEYMAPS_ROOT
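The new imports pair checkpoint with is_torch_version, the same combination diffusers itself uses to decide whether checkpoint() may be given the use_reentrant argument (available from PyTorch 1.11 on). A minimal sketch of that version-gated call, with a hypothetical wrapper name not taken from this diff:

    from diffusers.utils import is_torch_version
    from torch.utils.checkpoint import checkpoint

    def checkpointed_call(module, x):
        # use_reentrant exists from torch 1.11 on; False selects the
        # non-reentrant engine, which cooperates better with module wrappers.
        if is_torch_version(">=", "1.11.0"):
            return checkpoint(module, x, use_reentrant=False)
        return checkpoint(module, x)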
@@ -26,6 +28,7 @@ class ToolkitModuleMixin:
     ):
         if call_super_init:
             super().__init__(*args, **kwargs)
         self.org_module: torch.nn.Module = kwargs.get('org_module', None)
+        self.is_checkpointing = False
         self.is_normalizing = False
         self.normalize_scaler = 1.0
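The is_checkpointing flag gives each wrapped module a per-instance switch the surrounding network can flip when gradient checkpointing is turned on or off. A sketch of how such a toggle might be wired up on the network side (the method and helper names here are assumptions, not shown in this diff):

    def enable_gradient_checkpointing(self):
        # Hypothetical toggle: flip the flag on every wrapped module so
        # its forward() can route through checkpoint().
        for module in self.get_all_modules():
            module.is_checkpointing = True

    def disable_gradient_checkpointing(self):
        for module in self.get_all_modules():
            module.is_checkpointing = False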
@@ -65,6 +68,8 @@ class ToolkitModuleMixin:
             return multiplier_tensor.detach()
         else:
+            if isinstance(self.multiplier, torch.Tensor):
+                return self.multiplier.detach()
             return self.multiplier

     def _call_forward(self: Module, x):
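Returning a detached copy when the multiplier is a tensor keeps the scale factor out of the autograd graph, so backward never tries to propagate into it. A two-line illustration of the effect (illustrative only, not this repository's code):

    import torch

    m = torch.tensor(2.0, requires_grad=True)
    y = m.detach() * torch.ones(3)  # no grad history flows back to m
    assert not y.requires_grad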
@@ -111,6 +116,7 @@ class ToolkitModuleMixin:
         return lx * scale

     def forward(self: Module, x):
+        x = x.detach()
         org_forwarded = self.org_forward(x)
         lora_output = self._call_forward(x)
         multiplier = self.get_multiplier(lora_output)
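The single added line detaches the incoming activation before both the original branch and the LoRA branch consume it, presumably to keep the recomputed subgraph self-contained under checkpointing. For contrast, a common shape for a checkpointing-aware forward (a generic sketch of the standard pattern using the checkpoint import above and a hypothetical _combined_forward, not this repository's code):

    def forward(self, x):
        if self.training and self.is_checkpointing:
            # Recompute this module's activations during backward
            # instead of storing them; trades compute for memory.
            return checkpoint(self._combined_forward, x, use_reentrant=False)
        return self._combined_forward(x)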
@@ -236,7 +242,6 @@ class ToolkitNetworkMixin:
     ):
         keymap = self.get_keymap()
         save_keymap = {}
         if keymap is not None:
             for ldm_key, diffusers_key in keymap.items():
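The last hunk sits in the save path, where get_keymap() supplies an LDM-key to diffusers-key mapping that saved weight names are rewritten through. A minimal sketch of that kind of remap (the helper name and shape are illustrative, not this repository's implementation):

    def remap_state_dict(state_dict, keymap):
        # Rewrite each weight key through the keymap; keys without a
        # mapping pass through unchanged.
        if keymap is None:
            return state_dict
        return {keymap.get(k, k): v for k, v in state_dict.items()}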