mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -2,13 +2,15 @@ from collections import OrderedDict
|
||||
from typing import Union, Literal, List
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
import torch.functional as F
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||
apply_learnable_snr_gos, LearnableSNRGamma
|
||||
@@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.do_prior_prediction = False
|
||||
self.do_long_prompts = False
|
||||
self.do_guided_loss = False
|
||||
if self.train_config.inverted_mask_prior:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
@@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
loss = get_guidance_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
batch=batch,
|
||||
noise=noise,
|
||||
sd=self.sd,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_targeted_polarity(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
@@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
unconditional_diff = unconditional_latents - conditional_latents
|
||||
conditional_diff = conditional_latents - unconditional_latents
|
||||
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||
|
||||
unconditional_diff = (unconditional_latents - mean_latents)
|
||||
conditional_diff = (conditional_latents - mean_latents)
|
||||
|
||||
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||
|
||||
target_noise = noise + unconditional_signal
|
||||
# target_noise = noise + unconditional_signal
|
||||
|
||||
conditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents + conditional_signal,
|
||||
target_noise,
|
||||
mean_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
unconditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents,
|
||||
mean_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
@@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
self.network.multiplier = network_weight_list
|
||||
# double up everything to run it through all at once
|
||||
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
|
||||
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
|
||||
# since we are dividing the polarity from the middle out, we need to double our network
|
||||
# weights on training since the convergent point will be at half network strength
|
||||
|
||||
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
|
||||
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
|
||||
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
|
||||
self.network.multiplier = cat_network_weight_list
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=cat_timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
# remove baseline from our prediction to extract our differential prediction
|
||||
prediction = prediction - baseline_prediction
|
||||
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
unconditional_signal.float(),
|
||||
pred_pos = pred_pos - baseline_prediction
|
||||
pred_neg = pred_neg - baseline_prediction
|
||||
|
||||
pred_loss = torch.nn.functional.mse_loss(
|
||||
pred_pos.float(),
|
||||
unconditional_diff.float(),
|
||||
reduction="none"
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
pred_loss = pred_loss.mean([1, 2, 3])
|
||||
|
||||
loss = self.apply_snr(loss, timesteps)
|
||||
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||
pred_neg.float(),
|
||||
conditional_diff.float(),
|
||||
reduction="none"
|
||||
)
|
||||
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||
|
||||
loss = (pred_loss + pred_neg_loss) / 2.0
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_masked_polarity(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
|
||||
|
||||
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||
|
||||
# unconditional_diff = (unconditional_latents - mean_latents)
|
||||
# conditional_diff = (conditional_latents - mean_latents)
|
||||
|
||||
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||
|
||||
# make a differential mask
|
||||
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||
max_differential = \
|
||||
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
||||
differential_scaler = 1.0 / max_differential
|
||||
differential_mask = differential_mask * differential_scaler
|
||||
spread_point = 0.1
|
||||
# adjust mask to amplify the differential at 0.1
|
||||
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
|
||||
# clip it
|
||||
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
|
||||
|
||||
# target_noise = noise + unconditional_signal
|
||||
|
||||
conditional_noisy_latents = self.sd.add_noise(
|
||||
conditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
unconditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
inverse_noisy_latents = self.sd.add_noise(
|
||||
inverse_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
# baseline_prediction = self.sd.predict_noise(
|
||||
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
# timestep=timesteps,
|
||||
# guidance_scale=1.0,
|
||||
# **pred_kwargs # adapter residuals in here
|
||||
# ).detach()
|
||||
|
||||
# double up everything to run it through all at once
|
||||
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
|
||||
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
|
||||
# since we are dividing the polarity from the middle out, we need to double our network
|
||||
# weights on training since the convergent point will be at half network strength
|
||||
|
||||
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
|
||||
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
|
||||
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
|
||||
self.network.multiplier = cat_network_weight_list
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=cat_timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||
|
||||
# create a loss to balance the mean to 0 between the two predictions
|
||||
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
|
||||
|
||||
# pred_pos = pred_pos - baseline_prediction
|
||||
# pred_neg = pred_neg - baseline_prediction
|
||||
|
||||
pred_loss = torch.nn.functional.mse_loss(
|
||||
pred_pos.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
# apply mask
|
||||
pred_loss = pred_loss * (1.0 + differential_mask)
|
||||
pred_loss = pred_loss.mean([1, 2, 3])
|
||||
|
||||
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||
pred_neg.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
# apply inverse mask
|
||||
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
|
||||
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||
|
||||
# make a loss to balance the losses of the pos and neg so they are equal
|
||||
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
|
||||
#
|
||||
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
|
||||
#
|
||||
# # add a multiplier to balancing losses to make them the top priority
|
||||
# differential_mean_loss = differential_mean_loss
|
||||
|
||||
# remove the grads from the negative as it is only a balancing loss
|
||||
# pred_neg_loss = pred_neg_loss.detach()
|
||||
|
||||
# loss = pred_loss + pred_neg_loss + differential_mean_loss
|
||||
loss = pred_loss + pred_neg_loss
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
|
||||
@@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
||||
if batch.unconditional_latents is not None:
|
||||
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||
# do guided loss
|
||||
loss = self.get_guided_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
|
||||
Reference in New Issue
Block a user