mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 17:49:49 +00:00
Differential guidance working, but I may have a better way
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Union
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.basic import value_map, adain
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
@@ -202,20 +202,39 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# calculate the differential between our conditional (target image) and our unconditional ("bad" image)
|
||||
target_differential = unconditional_noisy_latents - conditional_noisy_latents
|
||||
|
||||
# scale the target differential by the scheduler
|
||||
# todo, scale it the right way
|
||||
# target_differential = self.sd.noise_scheduler.add_noise(
|
||||
# torch.zeros_like(target_differential),
|
||||
# target_differential,
|
||||
# timesteps
|
||||
# )
|
||||
|
||||
target_differential = target_differential.detach()
|
||||
|
||||
# add the target differential to the target latents as if it were noise with the scheduler scaled to
|
||||
# the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
|
||||
# properly
|
||||
guidance_latents = self.sd.noise_scheduler.add_noise(
|
||||
conditional_noisy_latents,
|
||||
target_differential,
|
||||
timesteps
|
||||
)
|
||||
# guidance_latents = self.sd.noise_scheduler.add_noise(
|
||||
# conditional_noisy_latents,
|
||||
# target_differential,
|
||||
# timesteps
|
||||
# )
|
||||
|
||||
# guidance_latents = conditional_noisy_latents + target_differential
|
||||
# target_noise = conditional_noisy_latents + target_differential
|
||||
|
||||
# With LoRA network bypassed, predict noise to get a baseline of what the network
|
||||
# wants to do with the latents + noise. Pass our target latents here for the input.
|
||||
target_unconditional = self.sd.predict_noise(
|
||||
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
target_conditional = self.sd.predict_noise(
|
||||
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
@@ -223,6 +242,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
# we calculate the networks current knowledge so we do not overlearn what we know
|
||||
current_knowledge = target_unconditional - target_conditional
|
||||
|
||||
# we now have the differential noise prediction needed to create our convergence target
|
||||
target_unknown_knowledge = target_differential - current_knowledge
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
@@ -231,7 +256,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# with LoRA active, predict the noise with the scaled differential latents added. This will allow us
|
||||
# the opportunity to predict the differential + noise that was added to the latents.
|
||||
prediction_unconditional = self.sd.predict_noise(
|
||||
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
@@ -240,18 +265,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# remove the baseline conditional prediction. This will leave only the divergence from the baseline and
|
||||
# the prediction of the added differential noise
|
||||
prediction_positive = prediction_unconditional - target_unconditional
|
||||
# prediction_positive = prediction_unconditional - target_unconditional
|
||||
prediction_positive = target_unconditional - prediction_unconditional
|
||||
|
||||
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
|
||||
# this is the diffusion training process.
|
||||
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
|
||||
# differential between the conditional and unconditional images
|
||||
|
||||
positive_loss = torch.nn.functional.mse_loss(
|
||||
prediction_positive.float(),
|
||||
target_differential.float(),
|
||||
target_unknown_knowledge.float(),
|
||||
reduction="none"
|
||||
)
|
||||
|
||||
# add adain loss
|
||||
positive_loss = positive_loss
|
||||
|
||||
positive_loss = positive_loss.mean([1, 2, 3])
|
||||
|
||||
# positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
|
||||
# send it backwards BEFORE switching network polarity
|
||||
positive_loss = self.apply_snr(positive_loss, timesteps)
|
||||
positive_loss = positive_loss.mean()
|
||||
|
||||
@@ -11,3 +11,25 @@ def flush(garbage_collect=True):
|
||||
torch.cuda.empty_cache()
|
||||
if garbage_collect:
|
||||
gc.collect()
|
||||
|
||||
|
||||
def adain(content_features, style_features, eps=1e-5):
    """Apply adaptive instance normalization (AdaIN) to ``content_features``.

    Each (sample, channel) slice of ``content_features`` is normalized to zero
    mean and unit standard deviation over dims 2 and 3, then rescaled and
    shifted so that it carries the mean and standard deviation of the
    corresponding slice of ``style_features``.

    Args:
        content_features: 4-D tensor; statistics are reduced over dims 2 and 3.
        style_features: 4-D tensor whose per-slice statistics are transferred.
            Its dims 2 and 3 may differ in size from the content tensor's,
            because only the reduced statistics (kept as size-1 dims) are used;
            dims 0 and 1 must broadcast against ``content_features``.
        eps: Small constant added to each variance before the square root to
            avoid division by zero. Defaults to ``1e-5``.

    Returns:
        Tensor with the same shape as ``content_features`` (given matching
        dims 0 and 1) carrying the style statistics.
    """

    def _channel_stats(features):
        # torch.var uses the unbiased estimator by default, which is NaN when
        # only a single element is reduced (e.g. a 1x1 feature map). Fall back
        # to the population variance (0) in that degenerate case only, so all
        # other inputs keep exactly the same numerics as before.
        unbiased = features.shape[2] * features.shape[3] > 1
        mean = torch.mean(features, dim=[2, 3], keepdim=True)
        var = torch.var(features, dim=[2, 3], unbiased=unbiased, keepdim=True)
        return mean, torch.sqrt(var + eps)

    # Step 1: per-slice statistics of the content and the style features
    content_mean, content_std = _channel_stats(content_features)
    style_mean, style_std = _channel_stats(style_features)

    # Step 2: normalize the content features
    normalized_content = (content_features - content_mean) / content_std

    # Step 3: scale and shift the normalized content with the style statistics
    stylized_content = normalized_content * style_std + style_mean

    return stylized_content
|
||||
|
||||
Reference in New Issue
Block a user