Tons of bug fixes and improvements to special training. Fixed slider training.

Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions


@@ -2,13 +2,15 @@ from collections import OrderedDict
from typing import Union, Literal, List
from diffusers import T2IAdapter
import torch.nn.functional as F
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
apply_learnable_snr_gos, LearnableSNRGamma
@@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
@@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess):
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
loss = get_guidance_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
sd=self.sd,
**kwargs
)
return loss
def get_guided_loss_targeted_polarity(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
@@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess):
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
unconditional_diff = unconditional_latents - conditional_latents
conditional_diff = conditional_latents - unconditional_latents
mean_latents = (conditional_latents + unconditional_latents) / 2.0
unconditional_diff = (unconditional_latents - mean_latents)
conditional_diff = (conditional_latents - mean_latents)
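# the two diffs are now equal and opposite offsets around the shared mean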
# we need to determine the amount of signal and noise that would be present at the current timestep
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
target_noise = noise + unconditional_signal
# target_noise = noise + unconditional_signal
conditional_noisy_latents = self.sd.add_noise(
unconditional_latents + conditional_signal,
target_noise,
mean_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = self.sd.add_noise(
unconditional_latents,
mean_latents,
noise,
timesteps
).detach()
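For intuition, here is a minimal standalone sketch of the mean/diff decomposition this hunk switches to; the tensors are made up and this is not the trainer's actual API:
import torch
cond = torch.randn(1, 4, 64, 64)    # stand-in conditional latents
uncond = torch.randn(1, 4, 64, 64)  # stand-in unconditional latents
mean_latents = (cond + uncond) / 2.0
cond_diff = cond - mean_latents
uncond_diff = uncond - mean_latents
# the two offsets mirror each other around the mean, which is why the training
# targets for the positive and negative passes are equal and opposite
assert torch.allclose(cond_diff, -uncond_diff, atol=1e-6)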
@@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs # adapter residuals in here
).detach()
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = network_weight_list
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
# since we split the polarity outward from the midpoint, we double the network weights
# during training, because the convergent point sits at half network strength
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
cat_network_weight_list = positive_network_weights + negative_network_weights
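# the first half of the doubled batch is run at +2x LoRA strength, the second half at -2x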
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = cat_network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=cat_timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# remove baseline from our prediction to extract our differential prediction
prediction = prediction - baseline_prediction
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
loss = torch.nn.functional.mse_loss(
prediction.float(),
unconditional_signal.float(),
pred_pos = pred_pos - baseline_prediction
pred_neg = pred_neg - baseline_prediction
pred_loss = torch.nn.functional.mse_loss(
pred_pos.float(),
unconditional_diff.float(),
reduction="none"
)
loss = loss.mean([1, 2, 3])
pred_loss = pred_loss.mean([1, 2, 3])
loss = self.apply_snr(loss, timesteps)
pred_neg_loss = torch.nn.functional.mse_loss(
pred_neg.float(),
conditional_diff.float(),
reduction="none"
)
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
loss = (pred_loss + pred_neg_loss) / 2.0
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
# detach it so the parent class can call backward on it without a graph and without throwing an error
loss = loss.detach()
loss.requires_grad_(True)
return loss
def get_guided_loss_masked_polarity(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
dtype = get_torch_dtype(self.train_config.dtype)
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
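# inverse_latents reflects the conditional latents through the unconditional ones (uncond - (cond - uncond))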
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
mean_latents = (conditional_latents + unconditional_latents) / 2.0
# unconditional_diff = (unconditional_latents - mean_latents)
# conditional_diff = (conditional_latents - mean_latents)
# we need to determine the amount of signal and noise that would be present at the current timestep
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
# make a differential mask
differential_mask = torch.abs(conditional_latents - unconditional_latents)
max_differential = \
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
differential_scaler = 1.0 / max_differential
differential_mask = differential_mask * differential_scaler
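# differential_mask is now normalized to [0, 1] per sample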
spread_point = 0.1
# adjust mask to amplify the differential at 0.1
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
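# this maps x to 10x - 0.9, so after clamping, values below ~0.09 go to 0 and values above ~0.19 saturate at 1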
# clip it
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
# target_noise = noise + unconditional_signal
conditional_noisy_latents = self.sd.add_noise(
conditional_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = self.sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
inverse_noisy_latents = self.sd.add_noise(
inverse_latents,
noise,
timesteps
).detach()
# Disable the LoRA network so we can get the parent network's prediction without it
self.network.is_active = False
self.sd.unet.eval()
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
# baseline_prediction = self.sd.predict_noise(
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
# timestep=timesteps,
# guidance_scale=1.0,
# **pred_kwargs # adapter residuals in here
# ).detach()
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
# run the positive and negative polarity passes with opposite-signed network weights
# (kept at unit strength here, unlike the targeted_polarity variant which doubles them)
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
cat_network_weight_list = positive_network_weights + negative_network_weights
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = cat_network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=cat_timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
# auxiliary loss that pushes the mean difference between the two predictions toward 0
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
# pred_pos = pred_pos - baseline_prediction
# pred_neg = pred_neg - baseline_prediction
pred_loss = torch.nn.functional.mse_loss(
pred_pos.float(),
noise.float(),
reduction="none"
)
# apply mask
pred_loss = pred_loss * (1.0 + differential_mask)
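# regions that changed between the conditional and unconditional latents get up to 2x loss weight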
pred_loss = pred_loss.mean([1, 2, 3])
pred_neg_loss = torch.nn.functional.mse_loss(
pred_neg.float(),
noise.float(),
reduction="none"
)
# apply inverse mask
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
# make a loss to balance the losses of the pos and neg so they are equal
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
#
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
#
# # add a multiplier to balancing losses to make them the top priority
# differential_mean_loss = differential_mean_loss
# remove the grads from the negative as it is only a balancing loss
# pred_neg_loss = pred_neg_loss.detach()
# loss = pred_loss + pred_neg_loss + differential_mean_loss
loss = pred_loss + pred_neg_loss
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
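For reference, a toy sketch (made-up tensors and a hypothetical predict stand-in, not the toolkit's predict_noise API) of the doubled-batch trick both guided losses rely on: stack the inputs twice, run the two halves with opposite-signed LoRA multipliers, then split the prediction back apart:
import torch
latents = torch.randn(2, 4, 64, 64)                 # stand-in noisy latents, batch of 2
cat_latents = torch.cat([latents, latents], dim=0)  # one forward pass covers both polarities
weights = [1.0, 1.0]                                 # illustrative per-module LoRA multipliers
cat_weights = [w * 1.0 for w in weights] + [w * -1.0 for w in weights]
def predict(x):
    # hypothetical stand-in for the model forward; the real code calls self.sd.predict_noise
    return x * 0.5
prediction = predict(cat_latents)
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)  # positive-weight half, negative-weight half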
@@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess):
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None:
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,