FINALLY fixed gradient checkpointing issue. Big batches baby.

Author: Jaret Burkett
Date: 2023-09-08 15:21:46 -06:00
Parent: cb91b0d6da
Commit: b01ab5d375
3 changed files with 16 additions and 14 deletions
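
For context, gradient checkpointing trades compute for memory: activations inside a checkpointed block are dropped during the forward pass and recomputed during backward, which is what frees up room for larger batches. A minimal, hypothetical sketch of the underlying PyTorch API (torch.utils.checkpoint); this is not code from this repository:

    import torch
    from torch.utils.checkpoint import checkpoint

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(512, 512),
                torch.nn.GELU(),
                torch.nn.Linear(512, 512),
            )

        def forward(self, x):
            # activations of self.net are recomputed during backward instead of stored;
            # use_reentrant=False is the non-reentrant variant available on torch >= 1.11
            return checkpoint(self.net, x, use_reentrant=False)

    block = Block()
    block.train()
    out = block(torch.randn(4, 512, requires_grad=True))
    out.sum().backward()  # recomputation of Block activations happens here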


@@ -34,8 +34,6 @@ class SDTrainer(BaseSDTrainProcess):
        dtype = get_torch_dtype(self.train_config.dtype)
        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
        network_weight_list = batch.get_network_weight_list()
        self.optimizer.zero_grad()
        flush()
        # text encoding
@@ -59,11 +57,10 @@ class SDTrainer(BaseSDTrainProcess):
        with network:
            with torch.set_grad_enabled(grad_on_text_encoder):
                conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
            if not grad_on_text_encoder:
                # detach the embeddings
                conditional_embeds = conditional_embeds.detach()
            self.optimizer.zero_grad()
            flush()
            # if not grad_on_text_encoder:
            #     # detach the embeddings
            #     conditional_embeds = conditional_embeds.detach()
            # flush()
            noise_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -73,7 +70,7 @@ class SDTrainer(BaseSDTrainProcess):
            )
            flush()
            # 9.18 gb
            noise = noise.to(self.device_torch, dtype=dtype)
            noise = noise.to(self.device_torch, dtype=dtype).detach()
            if self.sd.prediction_type == 'v_prediction':
                # v-parameterization training

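The SDTrainer hunks above adjust where self.optimizer.zero_grad() and flush() run relative to text encoding, and the noise tensor ends up detached before the loss. A minimal, self-contained sketch of the general step ordering involved (dummy model and tensors, not this repo's classes): gradients are zeroed before the forward pass that builds the graph, and the regression target is detached so gradients only flow through the prediction.

    import torch

    model = torch.nn.Linear(16, 16)                 # stands in for the UNet
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    noisy_latents = torch.randn(4, 16)
    noise = torch.randn(4, 16)                      # regression target

    optimizer.zero_grad()                           # clear grads before building the graph
    noise_pred = model(noisy_latents)
    loss = torch.nn.functional.mse_loss(noise_pred, noise.detach())
    loss.backward()
    optimizer.step()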

@@ -719,10 +719,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
            self.network.is_normalizing = True
        flush()
        ### HOOK ###
        loss_dict = self.hook_train_loop(batch)
        flush()
        # setup the networks to gradient checkpointing and everything works
        if self.embedding is not None or self.train_config.train_text_encoder:
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
@@ -731,6 +727,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                self.sd.text_encoder.train()
        self.sd.unet.train()
        ### HOOK ###
        loss_dict = self.hook_train_loop(batch)
        flush()
        # setup the networks to gradient checkpointing and everything works
        with torch.no_grad():
            if self.train_config.optimizer.lower().startswith('dadaptation') or \

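The BaseSDTrainProcess hunks move the ### HOOK ### call to hook_train_loop so it runs after the text encoder(s) and UNet have been put into .train() mode. That ordering matters because gradient checkpointing is commonly gated on the module being in training mode, so a forward pass taken while still in eval mode silently skips recomputation and keeps every activation in memory. An illustrative sketch (placeholder module, not this repo's classes):

    import torch
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = torch.nn.Linear(32, 32)
            self.gradient_checkpointing = True

        def forward(self, x):
            if self.training and self.gradient_checkpointing:
                # activations recomputed during backward; saves memory
                return checkpoint(self.inner, x, use_reentrant=False)
            # eval path: normal forward, full activation storage
            return self.inner(x)

    block = CheckpointedBlock()
    x = torch.randn(4, 32, requires_grad=True)

    block.eval()
    _ = block(x)    # still in eval mode: checkpointing is skipped
    block.train()
    _ = block(x)    # training mode: the checkpointed forward is used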

@@ -1,10 +1,12 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
import torch
from diffusers.utils import is_torch_version
from torch import nn
from torch.utils.checkpoint import checkpoint
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
@@ -26,6 +28,7 @@ class ToolkitModuleMixin:
    ):
        if call_super_init:
            super().__init__(*args, **kwargs)
        self.org_module: torch.nn.Module = kwargs.get('org_module', None)
        self.is_checkpointing = False
        self.is_normalizing = False
        self.normalize_scaler = 1.0
@@ -65,6 +68,8 @@ class ToolkitModuleMixin:
            return multiplier_tensor.detach()
        else:
            if isinstance(self.multiplier, torch.Tensor):
                return self.multiplier.detach()
            return self.multiplier

    def _call_forward(self: Module, x):
@@ -111,6 +116,7 @@ class ToolkitModuleMixin:
        return lx * scale

    def forward(self: Module, x):
        x = x.detach()
        org_forwarded = self.org_forward(x)
        lora_output = self._call_forward(x)
        multiplier = self.get_multiplier(lora_output)
@@ -236,7 +242,6 @@ class ToolkitNetworkMixin:
    ):
        keymap = self.get_keymap()
        save_keymap = {}
        if keymap is not None:
            for ldm_key, diffusers_key in keymap.items():
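
The third file pulls in torch.utils.checkpoint.checkpoint and diffusers' is_torch_version, and adds an is_checkpointing flag to the module mixin. The diff does not show how these are wired together, so the following is only a hedged guess at the intended pattern (hypothetical module, not this repo's _call_forward): wrap the module's inner forward in checkpoint when is_checkpointing is set, passing use_reentrant only on torch versions that accept it.

    import torch
    from diffusers.utils import is_torch_version
    from torch.utils.checkpoint import checkpoint

    class CheckpointableLoRAModule(torch.nn.Module):
        def __init__(self, dim: int = 32, rank: int = 4):
            super().__init__()
            self.lora_down = torch.nn.Linear(dim, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, dim, bias=False)
            self.is_checkpointing = False

        def _inner_forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.lora_up(self.lora_down(x))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.training and self.is_checkpointing:
                # drop intermediate activations and recompute them in backward
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                return checkpoint(self._inner_forward, x, **ckpt_kwargs)
            return self._inner_forward(x)

    module = CheckpointableLoRAModule()
    module.train()
    module.is_checkpointing = True
    out = module(torch.randn(2, 32, requires_grad=True))
    out.sum().backward()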