mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Updates to flow matching algo
This commit is contained in:
@@ -329,24 +329,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
||||
|
||||
elif self.sd.is_flow_matching:
|
||||
# only if preconditioning model outputs
|
||||
# if not preconditioning, (target = noise - batch.latents)
|
||||
|
||||
# if preconditioning outputs, target latents
|
||||
# model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
if self.train_config.target_noise_multiplier != 1.0:
|
||||
# we are adjusting the target noise, need to recompute the noisy latents with
|
||||
# the noise adjusted above
|
||||
with torch.no_grad():
|
||||
noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
|
||||
|
||||
noise_pred = precondition_model_outputs_flow_match(
|
||||
noise_pred,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
self.sd.noise_scheduler
|
||||
)
|
||||
target = batch.latents.detach()
|
||||
target = (noise - batch.latents).detach()
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -392,14 +375,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
|
||||
loss = loss_per_element
|
||||
else:
|
||||
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
|
||||
if self.sd.is_flow_matching and prior_pred is None:
|
||||
# outputs should be preprocessed latents
|
||||
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
|
||||
weighting = torch.ones_like(sigmas)
|
||||
loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
|
||||
|
||||
elif self.train_config.loss_type == "mae":
|
||||
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
|
||||
if self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
Reference in New Issue
Block a user