diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index c9b7178f..a528fd45 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -34,8 +34,6 @@ class SDTrainer(BaseSDTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() - - self.optimizer.zero_grad() flush() # text encoding @@ -59,11 +57,10 @@ class SDTrainer(BaseSDTrainProcess): with network: with torch.set_grad_enabled(grad_on_text_encoder): conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype) - if not grad_on_text_encoder: - # detach the embeddings - conditional_embeds = conditional_embeds.detach() - self.optimizer.zero_grad() - flush() + # if not grad_on_text_encoder: + # # detach the embeddings + # conditional_embeds = conditional_embeds.detach() + # flush() noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), @@ -73,7 +70,7 @@ class SDTrainer(BaseSDTrainProcess): ) flush() # 9.18 gb - noise = noise.to(self.device_torch, dtype=dtype) + noise = noise.to(self.device_torch, dtype=dtype).detach() if self.sd.prediction_type == 'v_prediction': # v-parameterization training diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d2062912..cb1a6c23 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -719,10 +719,6 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.network is not None and self.network_config.normalize and not self.network.is_normalizing: self.network.is_normalizing = True flush() - ### HOOK ### - loss_dict = self.hook_train_loop(batch) - flush() - # setup the networks to gradient checkpointing and everything works if self.embedding is not None or self.train_config.train_text_encoder: if isinstance(self.sd.text_encoder, list): for te in self.sd.text_encoder: @@ -731,6 +727,10 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.text_encoder.train() self.sd.unet.train() + ### HOOK ### + loss_dict = self.hook_train_loop(batch) + flush() + # setup the networks to gradient checkpointing and everything works with torch.no_grad(): if self.train_config.optimizer.lower().startswith('dadaptation') or \ diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 75d74c6d..7f314748 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -1,10 +1,12 @@ import json import os from collections import OrderedDict -from typing import Optional, Union, List, Type, TYPE_CHECKING +from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any import torch +from diffusers.utils import is_torch_version from torch import nn +from torch.utils.checkpoint import checkpoint from toolkit.metadata import add_model_hash_to_meta from toolkit.paths import KEYMAPS_ROOT @@ -26,6 +28,7 @@ class ToolkitModuleMixin: ): if call_super_init: super().__init__(*args, **kwargs) + self.org_module: torch.nn.Module = kwargs.get('org_module', None) self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 @@ -65,6 +68,8 @@ class ToolkitModuleMixin: return multiplier_tensor.detach() else: + if isinstance(self.multiplier, torch.Tensor): + return self.multiplier.detach() return self.multiplier def _call_forward(self: Module, x): @@ -111,6 +116,7 @@ class ToolkitModuleMixin: return lx * scale def forward(self: Module, x): + x = x.detach() org_forwarded = self.org_forward(x) lora_output = self._call_forward(x) multiplier = self.get_multiplier(lora_output) @@ -236,7 +242,6 @@ class ToolkitNetworkMixin: ): keymap = self.get_keymap() - save_keymap = {} if keymap is not None: for ldm_key, diffusers_key in keymap.items():