Adjusted flow matching so target noise multiplier works properly with it.

This commit is contained in:
Jaret Burkett
2024-08-05 11:40:05 -06:00
parent 0ea27011d5
commit edb7e827ee
4 changed files with 35 additions and 26 deletions

View File

@@ -31,7 +31,7 @@ from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
def flush():
@@ -328,14 +328,24 @@ class SDTrainer(BaseSDTrainProcess):
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif self.sd.is_rectified_flow:
elif self.sd.is_flow_matching:
# only if preconditioning model outputs
# if not preconditioning, (target = noise - batch.latents)
# target = noise - batch.latents
# if preconditioning outputs, target latents
# model_pred = model_pred * (-sigmas) + noisy_model_input
if self.train_config.target_noise_multiplier != 1.0:
# we are adjusting the target noise, need to recompute the noisy latents with
# the noise adjusted above
with torch.no_grad():
noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
noise_pred = precondition_model_outputs_flow_match(
noise_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
target = batch.latents.detach()
else:
target = noise
@@ -383,7 +393,7 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss_per_element
else:
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
if self.sd.is_rectified_flow and prior_pred is None:
if self.sd.is_flow_matching and prior_pred is None:
# outputs should be preprocessed latents
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
weighting = torch.ones_like(sigmas)