mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
More fixes for noise schedules and fixed targeted guidance inverted masked prior
@@ -195,22 +195,34 @@ class SDTrainer(BaseSDTrainProcess):
             if batch.unconditional_latents is not None:
                 # unconditional latents are the "neutral" images. Add noise here identical to
                 # the noise added to the conditional latents, at the same timesteps
-                unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
-                    batch.unconditional_latents, noise, timesteps
-                )
+                # unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
+                #     batch.unconditional_latents, noise, timesteps
+                # )
+                unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps)

                 # calculate the differential between our conditional (target image) and our unconditional (neutral image)
                 target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
                 target_differential_noise = target_differential_noise.detach()

+                # Calculate the mean along dim=1, keeping dimensions
+                mean_chan = torch.abs(torch.mean(target_differential_noise, dim=1, keepdim=True))
+
+                # Create a mask with 0s where values are between 0.0 and 0.01, otherwise 1s
+                mask = torch.where((mean_chan >= 0.0) & (mean_chan <= 0.01), 0.0, 1.0)
+
+                # Duplicate the mask along dim 1 to match the shape of target_differential_noise
+                mask = mask.expand_as(target_differential_noise)
+                # this mask is now 1 for our target differential and 0 for everything else
+
                 # add the target differential to the target latents as if it were noise with the scheduler, scaled to
                 # the current timestep. Scaling the noise here is important as it scales our guidance to the current
                 # timestep. This is the key to making the guidance work.
-                guidance_latents = self.sd.noise_scheduler.add_noise(
-                    conditional_noisy_latents,
-                    target_differential_noise,
-                    timesteps
-                )
+                # guidance_latents = self.sd.noise_scheduler.add_noise(
+                #     conditional_noisy_latents,
+                #     target_differential_noise,
+                #     timesteps
+                # )
+                guidance_latents = self.sd.add_noise(conditional_noisy_latents, target_differential_noise, timesteps)

                 # Disable the LoRA network so we can predict parent network knowledge without it
                 self.network.is_active = False
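For reference, the guidance construction in this hunk can be exercised in isolation. Below is a minimal sketch, assuming a diffusers DDPMScheduler; the tensors are random stand-ins and every name is a placeholder, not part of the repo. The closed-form in the last comment is DDPM's; other schedulers scale differently, which is what the sd.add_noise wrapper exists to handle.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# stand-ins for a batch of 2 encoded latents (4 channels, 8x8)
conditional_latents = torch.randn(2, 4, 8, 8)
unconditional_latents = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(conditional_latents)
timesteps = torch.randint(0, 1000, (2,))

# noise both with the same noise at the same timesteps
conditional_noisy = scheduler.add_noise(conditional_latents, noise, timesteps)
unconditional_noisy = scheduler.add_noise(unconditional_latents, noise, timesteps)

# the differential isolates where target and neutral images disagree;
# the shared noise term cancels out of the subtraction
target_differential = (unconditional_noisy - conditional_noisy).detach()

# channel-mean magnitude -> binary mask: 1 where the images differ, 0 elsewhere
mean_chan = torch.abs(target_differential.mean(dim=1, keepdim=True))
mask = torch.where(mean_chan <= 0.01, 0.0, 1.0).expand_as(target_differential)

# treat the differential as "noise" so the scheduler scales it to the timestep:
# for DDPM, add_noise(x, n, t) = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * n
guidance_latents = scheduler.add_noise(conditional_noisy, target_differential, timesteps)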
@@ -254,7 +266,22 @@ class SDTrainer(BaseSDTrainProcess):
                     reduction="none"
                 )

+                # multiply by our mask
+                loss = loss * mask
                 loss = loss.mean([1, 2, 3])
+
+                # calculate the inverse prior loss to match the baseline prediction
+                unmasked_prior_loss = torch.nn.functional.mse_loss(
+                    baseline_prediction.float(),
+                    prediction.float(),
+                    reduction="none"
+                )
+                # multiply by the inverted mask
+                unmasked_prior_loss = unmasked_prior_loss * (1.0 - mask)
+                unmasked_prior_loss = unmasked_prior_loss.mean([1, 2, 3])
+                # add the unmasked prior loss to the masked loss
+                loss = loss + unmasked_prior_loss

                 loss = self.apply_snr(loss, timesteps)
                 loss = loss.mean()
                 loss.backward()
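The combination above keeps gradient pressure everywhere: the prediction is pulled toward the guidance target inside the mask, and toward the baseline (LoRA-disabled) prediction outside it, so regions outside the target differential are not disturbed. A minimal standalone sketch of that combination, with placeholder tensor names and baseline_prediction assumed already detached:

import torch
import torch.nn.functional as F

def masked_guidance_loss(prediction, guidance_target, baseline_prediction, mask):
    # masked term: match the guidance target where the images differ
    masked = F.mse_loss(prediction.float(), guidance_target.float(), reduction="none") * mask
    # inverted-mask prior term: match the frozen baseline everywhere else
    prior = F.mse_loss(prediction.float(), baseline_prediction.float(), reduction="none") * (1.0 - mask)
    # reduce each term per sample, then combine
    return masked.mean([1, 2, 3]) + prior.mean([1, 2, 3])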
@@ -775,7 +775,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if batch.unconditional_latents is not None:
             batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier

-        noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = self.sd.add_noise(latents, noise, timesteps)

         # determine scaled noise
         # todo do we need to scale this or does it always predict full intensity
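With this change, every noising call in the train loop goes through the StableDiffusion.add_noise wrapper (defined in the next hunk) instead of calling the scheduler directly, so per-scheduler quirks are handled in one place. Hypothetical call shape, assuming sd is a loaded StableDiffusion instance:

# latents: [B, C, H, W]; noise: same shape; timesteps: [B], or a single
# value that the wrapper broadcasts across the batch
noisy_latents = sd.add_noise(latents, noise, timesteps)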
@@ -643,6 +643,80 @@ class StableDiffusion:
         else:
             return None

+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # We handle adding noise for the various schedulers here. Some schedulers
+        # reference timesteps while others reference indices, so handle both cases.
+        # get the scheduler class name
+        scheduler_class_name = self.noise_scheduler.__class__.__name__
+
+        index_noise_schedulers = [
+            'DPMSolverMultistepScheduler',
+            'EulerDiscreteScheduler',
+        ]
+
+        # todo handle if timesteps is a single value
+
+        original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
+        noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
+        timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
+
+        if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
+            timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
+
+        noisy_latents_chunks = []
+
+        for idx in range(original_samples.shape[0]):
+            if scheduler_class_name not in index_noise_schedulers:
+                # convert timesteps to schedule indices
+                noise_timesteps = [(self.noise_scheduler.timesteps == t).nonzero().item() for t in timesteps_chunks[idx]]
+                noise_timesteps = torch.tensor(noise_timesteps, device=self.device_torch)
+            else:
+                noise_timesteps = timesteps_chunks[idx]
+
+            # add_noise for the DPM solver is broken, so we do it ourselves
+            if scheduler_class_name == 'DPMSolverMultistepScheduler':
+                # Make sure sigmas and timesteps have the same device and dtype as original_samples
+                sigmas = self.noise_scheduler.sigmas.to(device=original_samples_chunks[idx].device, dtype=original_samples_chunks[idx].dtype)
+                if original_samples_chunks[idx].device.type == "mps" and torch.is_floating_point(noise_timesteps):
+                    # mps does not support float64
+                    schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device, dtype=torch.float32)
+                    noise_timesteps = noise_timesteps.to(original_samples_chunks[idx].device, dtype=torch.float32)
+                else:
+                    schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device)
+                    noise_timesteps = noise_timesteps.to(original_samples_chunks[idx].device)
+
+                # find only the first match; there can be duplicate timesteps,
+                # which breaks the one-liner below on some schedulers
+                # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+                step_indices = []
+                for t in noise_timesteps:
+                    for i, st in enumerate(schedule_timesteps):
+                        if st == t:
+                            step_indices.append(i)
+                            break
+
+                sigma = sigmas[step_indices].flatten()
+                while len(sigma.shape) < len(original_samples.shape):
+                    sigma = sigma.unsqueeze(-1)
+
+                alpha_t, sigma_t = self.noise_scheduler._sigma_to_alpha_sigma_t(sigma)
+                noisy_samples = alpha_t * original_samples_chunks[idx] + sigma_t * noise_chunks[idx]
+                noisy_latents = noisy_samples
+            else:
+                noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], noise_timesteps)
+            noisy_latents_chunks.append(noisy_latents)
+
+        noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
+        return noisy_latents

     def predict_noise(
         self,
         latents: torch.Tensor,
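The manual branch for DPMSolverMultistepScheduler reproduces the scheduler's forward-noising math from its sigma schedule. In diffusers, _sigma_to_alpha_sigma_t maps a sigma to the (alpha_t, sigma_t) pair of the variance-preserving parameterization; a sketch of the default (non-flow-matching) case, under the assumption that the internal helper behaves like this:

import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
    # variance-preserving parameterization: alpha_t**2 + sigma_t**2 == 1
    alpha_t = 1.0 / (sigma ** 2 + 1.0) ** 0.5
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

# noisy = alpha_t * clean + sigma_t * noise, matching the branch above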
@@ -776,7 +776,14 @@ def apply_snr_weight(
 ):
     # will get it from the noise scheduler if it exists, or calculate it if not
     all_snr = get_all_snr(noise_scheduler, loss.device)
-    step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+    # find only the first match; the one-liner below breaks on schedulers with duplicate timesteps
+    # step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+    step_indices = []
+    for t in timesteps:
+        for i, st in enumerate(noise_scheduler.timesteps):
+            if st == t:
+                step_indices.append(i)
+                break
     snr = torch.stack([all_snr[t] for t in step_indices])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
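For context, gamma_over_snr feeds a min-SNR-style weighting (Hang et al., "Efficient Diffusion Training via Min-SNR Weighting Strategy"): each timestep's loss is scaled by min(SNR(t), gamma) / SNR(t), which equals min(gamma / SNR(t), 1) and caps the contribution of low-noise timesteps. The hunk is truncated above, so this is a sketch of the general weighting, not the repo's exact tail:

import torch

def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # min-SNR weighting: w(t) = min(gamma / SNR(t), 1).
    # High-SNR (low-noise) timesteps are down-weighted so they don't dominate training.
    gamma_over_snr = gamma / snr
    return torch.clamp(gamma_over_snr, max=1.0)

# loss = loss * min_snr_weight(snr)  # applied per sample before the final mean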