mirror of https://github.com/ostris/ai-toolkit.git
Switched to trailing timestep spacing to make timesteps more consistent across schedulers. Homed in on targeted guidance. It is finally perfect. (I think)
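For context: "trailing" spacing anchors the sampled timesteps at the high-noise end of the schedule, so every scheduler sees the same grid and the final timestep is always included; "leading" spacing starts at zero and never reaches it. A minimal sketch of the two conventions (the function names are illustrative; the behavior follows the timestep_spacing option popularized by Hugging Face diffusers):

    # Two common ways to pick n timesteps from a 1000-step training schedule.
    # "trailing" always includes the final, highest-noise timestep (999),
    # which keeps the grid consistent across schedulers.
    import numpy as np

    def leading_spacing(n: int, num_train_timesteps: int = 1000) -> np.ndarray:
        # starts at 0; the last (highest-noise) timestep is never sampled
        step = num_train_timesteps // n
        return (np.arange(0, n) * step)[::-1].astype(np.int64)

    def trailing_spacing(n: int, num_train_timesteps: int = 1000) -> np.ndarray:
        # anchored at num_train_timesteps; always ends at timestep 999
        return (np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / n))
                .astype(np.int64) - 1)

    print(leading_spacing(10))   # [900 800 700 600 500 400 300 200 100   0]
    print(trailing_spacing(10))  # [999 899 799 699 599 499 399 299 199  99]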
@@ -189,40 +189,15 @@ class SDTrainer(BaseSDTrainProcess):
         ):
             with torch.no_grad():
                 # Perform targeted guidance (working title)
-                conditional_noisy_latents = noisy_latents  # target images
+                conditional_noisy_latents = noisy_latents.detach()  # target images
                 dtype = get_torch_dtype(self.train_config.dtype)
 
                 if batch.unconditional_latents is not None:
                     # unconditional latents are the "neutral" images. Add noise here identical to
                     # the noise added to the conditional latents, at the same timesteps
-                    # unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
-                    #     batch.unconditional_latents, noise, timesteps
-                    # )
-                    unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps)
-
-                    # calculate the differential between our conditional (target image) and our unconditional (neutral image)
-                    target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
-                    target_differential_noise = target_differential_noise.detach()
+                    unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps).detach()
 
-                    # Calculate the mean along dim=1, keep dimensions
-                    mean_chan = torch.abs(torch.mean(target_differential_noise, dim=1, keepdim=True))
-
-                    # Create a mask with 0s where values are between 0.0 and 0.01, otherwise 1s
-                    mask = torch.where((mean_chan >= 0.0) & (mean_chan <= 0.01), 0.0, 1.0)
-
-                    # Duplicate the mask along dim 1 to match the shape of target_differential_noise
-                    mask = mask.expand_as(target_differential_noise)
-                    # this mask is now a 1 for our target differential and 0 for everything else
-
-                    # add the target differential to the target latents as if it were noise with the scheduler, scaled to
-                    # the current timestep. Scaling the noise here is important as it scales our guidance to the current
-                    # timestep. This is the key to making the guidance work.
-                    # guidance_latents = self.sd.noise_scheduler.add_noise(
-                    #     conditional_noisy_latents,
-                    #     target_differential_noise,
-                    #     timesteps
-                    # )
-                    guidance_latents = self.sd.add_noise(conditional_noisy_latents, target_differential_noise, timesteps)
 
                     # Disable the LoRA network so we can predict parent network knowledge without it
                     self.network.is_active = False
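The differential-and-mask logic above is the heart of the idea: because the conditional and unconditional latents receive the same noise at the same timesteps, subtracting the two noisy latents cancels the noise and leaves only the timestep-scaled image difference, and thresholding its channel-mean magnitude yields a binary mask over the changed region. A toy reproduction of that computation; the DDPM-style add_noise below is a stand-in assumption, not ai-toolkit's self.sd.add_noise:

    import torch

    def add_noise(latents, noise, t, alphas_cumprod):
        # standard DDPM forward process: x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps
        a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
        return a_t.sqrt() * latents + (1.0 - a_t).sqrt() * noise

    torch.manual_seed(0)
    alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
    conditional = torch.randn(2, 4, 8, 8)      # "target" latents
    unconditional = conditional.clone()
    unconditional[:, :, :4, :] += 0.5          # "neutral" latents differ in one region
    noise = torch.randn_like(conditional)      # ONE noise tensor, shared by both
    timesteps = torch.tensor([250, 750])

    conditional_noisy = add_noise(conditional, noise, timesteps, alphas_cumprod).detach()
    unconditional_noisy = add_noise(unconditional, noise, timesteps, alphas_cumprod).detach()

    # the shared noise cancels: sqrt(a_t) * (unconditional - conditional) remains
    target_differential = unconditional_noisy - conditional_noisy

    # channel-mean magnitude, thresholded into a 0/1 mask over the changed region
    mean_chan = torch.abs(torch.mean(target_differential, dim=1, keepdim=True))
    mask = torch.where(mean_chan <= 0.01, 0.0, 1.0).expand_as(target_differential)
    print(mask[0, 0])  # ~1s in the altered top half, 0s elsewhere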
@@ -231,7 +206,7 @@ class SDTrainer(BaseSDTrainProcess):
                     # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
                     # This acts as our control to preserve the unaltered parts of the image.
                     baseline_prediction = self.sd.predict_noise(
-                        latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
+                        latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                         conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                         timestep=timesteps,
                         guidance_scale=1.0,
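This baseline pass works because the LoRA can be switched off in place (self.network.is_active = False, set at the end of the first hunk), letting the frozen parent model answer alone before the LoRA is re-enabled for the trainable pass. A toy sketch of that toggle pattern; LoRALinear and its is_active flag are stand-ins for ai-toolkit's network class:

    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        def __init__(self, base: nn.Linear, rank: int = 4):
            super().__init__()
            self.base = base                 # frozen parent weight
            self.down = nn.Linear(base.in_features, rank, bias=False)
            self.up = nn.Linear(rank, base.out_features, bias=False)
            nn.init.zeros_(self.up.weight)   # LoRA starts as a no-op
            self.is_active = True

        def forward(self, x):
            out = self.base(x)
            if self.is_active:
                out = out + self.up(self.down(x))
            return out

    layer = LoRALinear(nn.Linear(16, 16))
    x = torch.randn(1, 16)

    layer.is_active = False
    with torch.no_grad():
        baseline = layer(x)   # parent knowledge only: the "control"
    layer.is_active = True
    prediction = layer(x)     # parent + LoRA, kept in the graph for the loss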
@@ -245,42 +220,19 @@ class SDTrainer(BaseSDTrainProcess):
 
                     # do our prediction with LoRA active on the scaled guidance latents
                     prediction = self.sd.predict_noise(
-                        latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
+                        latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                         conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                         timestep=timesteps,
                         guidance_scale=1.0,
                         **pred_kwargs  # adapter residuals in here
                     )
 
                     # remove the baseline prediction from our prediction to get the differential between the two;
                     # all that should be left is the differential between the conditional and unconditional images
                     pred_differential_noise = prediction - baseline_prediction
 
                     # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents,
                     # not the timestep-scaled noise that was added. This is the diffusion training process.
                     # This will guide the network to make the same predictions it previously made for everything EXCEPT the
                     # differential between the conditional and unconditional images (target)
                     loss = torch.nn.functional.mse_loss(
                         pred_differential_noise.float(),
                         target_differential_noise.float(),
                         reduction="none"
                     )
 
                     # multiply by our mask
                     loss = loss * mask
                     loss = loss.mean([1, 2, 3])
                     # calculate inverse to match baseline prediction
                     unmasked_prior_loss = torch.nn.functional.mse_loss(
-                        baseline_prediction.float(),
+                        prediction.float(),
                         baseline_prediction.float(),
                         reduction="none"
                     )
                     # multiply by the inverted mask
                     unmasked_prior_loss = unmasked_prior_loss * (1.0 - mask)
                     unmasked_prior_loss = unmasked_prior_loss.mean([1, 2, 3])
                     # add the unmasked prior loss to the masked loss
                     loss = loss + unmasked_prior_loss
 
-                    loss = loss.mean([1, 2, 3])
 
                     loss = self.apply_snr(loss, timesteps)
                     loss = loss.mean()
 
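Read together, the loss lines combine two objectives: inside the mask the LoRA's change (prediction minus baseline) is pushed toward the target differential, and outside the mask the prediction is pulled back to the baseline so the parent model's behaviour is preserved. A condensed sketch of that combined objective; the shapes, and leaving the reduction per-sample so an SNR-style timestep weighting like apply_snr can be applied, are assumptions:

    import torch
    import torch.nn.functional as F

    def targeted_guidance_loss(prediction, baseline_prediction,
                               target_differential_noise, mask):
        # what the LoRA changed relative to the frozen parent model
        pred_differential = prediction - baseline_prediction

        # inside the mask: match the target differential
        diff_loss = F.mse_loss(pred_differential.float(),
                               target_differential_noise.float(),
                               reduction="none") * mask

        # outside the mask: prior preservation, match the baseline
        prior_loss = F.mse_loss(prediction.float(),
                                baseline_prediction.float(),
                                reduction="none") * (1.0 - mask)

        # reduce to one value per sample; a timestep weighting (e.g. min-SNR)
        # could be applied here before the final batch mean
        return (diff_loss + prior_loss).mean(dim=[1, 2, 3])

    prediction = torch.randn(2, 4, 8, 8, requires_grad=True)
    baseline = torch.randn(2, 4, 8, 8)
    target_differential = torch.randn(2, 4, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float().expand_as(prediction)

    loss = targeted_guidance_loss(prediction, baseline, target_differential, mask).mean()
    loss.backward()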