Mirror of https://github.com/ostris/ai-toolkit.git
Added experimental modified sigma sqrt weight mapping for linear timestep scheduling for flowmatching
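In short: for flow-matching models trained with linear timesteps, the elementwise loss is now scaled by a per-timestep weight taken from the scheduler. The weight is sigma ** -2 (despite the "sqrt" name), clamped at 1e4 and normalized to a mean of 1; the earlier cosmap weighting is left in comments for comparison.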
@@ -375,12 +375,18 @@ class SDTrainer(BaseSDTrainProcess):
            loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
            loss = loss_per_element
        else:
            # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100

            if self.train_config.loss_type == "mae":
                loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
            else:
                loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

            # handle linear timesteps and only adjust the weight of the timesteps
            if self.sd.is_flow_matching and self.train_config.linear_timesteps:
                # calculate the weights for the timesteps
                timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
                loss = loss * timestep_weight

        if self.train_config.do_prior_divergence and prior_pred is not None:
            loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
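For intuition, here is a minimal standalone sketch of what the multiply above does; the [B, C, H, W] latent shape, the view(-1, 1, 1, 1) reshape for broadcasting, and the final .mean() reduction are illustrative assumptions, not code from this commit.

import torch

# minimal sketch (assumed shapes): elementwise loss over a latent batch,
# scaled by one weight per sample's timestep, then reduced to a scalar
B, C, H, W = 4, 16, 32, 32
pred = torch.randn(B, C, H, W)
target = torch.randn(B, C, H, W)
timestep_weight = torch.tensor([0.5, 1.0, 1.5, 1.0])  # as if from get_weights_for_timesteps

loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
loss = loss * timestep_weight.view(-1, 1, 1, 1)  # reshape (assumed) so it broadcasts over C, H, W
loss = loss.mean()  # assumed final reduction before backward()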
@@ -1,3 +1,4 @@
import math
from typing import Union

from diffusers import FlowMatchEulerDiscreteScheduler
@@ -5,6 +6,45 @@ import torch
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        with torch.no_grad():
            # create weights for timesteps
            num_timesteps = 1000

            # generate the multiplier based on cosmap loss weighing
            # this is only used on linear timesteps for now

            # cosine map weighing is higher in the middle and lower at the ends
            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
            # cosmap_weighing = 2 / (math.pi * bot)

            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
            sigma_sqrt_weighing = (self.sigmas ** -2.0).float()
            # clip at 1e4 (1e6 is too high)
            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
            # bring to a mean of 1
            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()

            # Create linear timesteps from 1000 to 0
            timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')

            self.linear_timesteps = timesteps
            # self.linear_timesteps_weights = cosmap_weighing
            self.linear_timesteps_weights = sigma_sqrt_weighing

            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
            pass
    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
        # Get the indices of the timesteps
        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

        # Get the weights for the timesteps
        weights = self.linear_timesteps_weights[step_indices].flatten()

        return weights
    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
        sigmas = self.sigmas.to(device=device, dtype=dtype)
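The mirrored diff is cut off mid-method here. As a sketch only, the helper presumably continues along the lines of the standard get_sigmas pattern used in the diffusers training scripts:

        # continuation sketch (assumed, not the verbatim commit): map each sampled
        # timestep to its sigma, then pad trailing dims to broadcast against n_dim tensors
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma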
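To compare the two weighting curves contrasted in the __init__ comments, a standalone sketch; the descending sigma grid is a hypothetical stand-in for the scheduler's self.sigmas:

import math
import torch

# hypothetical sigma grid standing in for self.sigmas (high noise -> low noise)
sigmas = torch.linspace(1.0, 1e-3, 1000)

# cosmap weighting (the commented-out option): lowest at the schedule ends, peaking mid-schedule
bot = 1 - 2 * sigmas + 2 * sigmas ** 2
cosmap_weighing = 2 / (math.pi * bot)

# sigma "sqrt" weighting as committed: sigma ** -2, clamped, mean-normalized,
# so late (low-sigma) steps get much larger weights
sigma_sqrt_weighing = (sigmas ** -2.0).float()
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()

print(cosmap_weighing[::250])       # roughly flat, with a bump mid-schedule
print(sigma_sqrt_weighing[::250])   # small early, large late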
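And a hypothetical end-to-end lookup against the class above; it assumes a default-constructed scheduler already exposes its 1000-entry training timesteps, which may differ across diffusers versions:

# assumes the CustomFlowMatchEulerDiscreteScheduler defined in this commit is in scope
scheduler = CustomFlowMatchEulerDiscreteScheduler()

# pick a few timesteps that exist in the schedule, then fetch their loss weights
timesteps = scheduler.timesteps[[0, 500, 999]]
weights = scheduler.get_weights_for_timesteps(timesteps)
print(weights)  # one mean-normalized weight per sampled timestep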