mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Adjusted flow matching so target noise multiplier works properly with it.
This commit is contained in:
@@ -31,7 +31,7 @@ from jobs.process import BaseSDTrainProcess
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from diffusers import EMAModel
|
from diffusers import EMAModel
|
||||||
import math
|
import math
|
||||||
|
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -328,14 +328,24 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
||||||
|
|
||||||
elif self.sd.is_rectified_flow:
|
elif self.sd.is_flow_matching:
|
||||||
# only if preconditioning model outputs
|
# only if preconditioning model outputs
|
||||||
# if not preconditioning, (target = noise - batch.latents)
|
# if not preconditioning, (target = noise - batch.latents)
|
||||||
|
|
||||||
|
|
||||||
# target = noise - batch.latents
|
|
||||||
|
|
||||||
# if preconditioning outputs, target latents
|
# if preconditioning outputs, target latents
|
||||||
|
# model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||||
|
if self.train_config.target_noise_multiplier != 1.0:
|
||||||
|
# we are adjusting the target noise, need to recompute the noisy latents with
|
||||||
|
# the noise adjusted above
|
||||||
|
with torch.no_grad():
|
||||||
|
noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
|
||||||
|
|
||||||
|
noise_pred = precondition_model_outputs_flow_match(
|
||||||
|
noise_pred,
|
||||||
|
noisy_latents,
|
||||||
|
timesteps,
|
||||||
|
self.sd.noise_scheduler
|
||||||
|
)
|
||||||
target = batch.latents.detach()
|
target = batch.latents.detach()
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
@@ -383,7 +393,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
loss = loss_per_element
|
loss = loss_per_element
|
||||||
else:
|
else:
|
||||||
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
|
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
|
||||||
if self.sd.is_rectified_flow and prior_pred is None:
|
if self.sd.is_flow_matching and prior_pred is None:
|
||||||
# outputs should be preprocessed latents
|
# outputs should be preprocessed latents
|
||||||
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
|
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
|
||||||
weighting = torch.ones_like(sigmas)
|
weighting = torch.ones_like(sigmas)
|
||||||
|
|||||||
@@ -959,7 +959,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||||
|
|
||||||
# do flow matching
|
# do flow matching
|
||||||
# if self.sd.is_rectified_flow:
|
# if self.sd.is_flow_matching:
|
||||||
# u = compute_density_for_timestep_sampling(
|
# u = compute_density_for_timestep_sampling(
|
||||||
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
||||||
# batch_size=batch_size,
|
# batch_size=batch_size,
|
||||||
|
|||||||
@@ -166,9 +166,9 @@ class StableDiffusion:
|
|||||||
|
|
||||||
self.config_file = None
|
self.config_file = None
|
||||||
|
|
||||||
self.is_rectified_flow = False
|
self.is_flow_matching = False
|
||||||
if self.is_flux or self.is_v3 or self.is_auraflow:
|
if self.is_flux or self.is_v3 or self.is_auraflow:
|
||||||
self.is_rectified_flow = True
|
self.is_flow_matching = True
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.is_loaded:
|
if self.is_loaded:
|
||||||
@@ -1337,20 +1337,6 @@ class StableDiffusion:
|
|||||||
)
|
)
|
||||||
return torch.cat(out_chunks, dim=0)
|
return torch.cat(out_chunks, dim=0)
|
||||||
|
|
||||||
def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
|
|
||||||
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
|
|
||||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
|
||||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
|
||||||
out_chunks = []
|
|
||||||
# unsqueeze if timestep is zero dim
|
|
||||||
for idx in range(model_output.shape[0]):
|
|
||||||
sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
|
|
||||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
|
||||||
# Preconditioning of the model outputs.
|
|
||||||
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
|
|
||||||
out_chunks.append(out)
|
|
||||||
return torch.cat(out_chunks, dim=0)
|
|
||||||
|
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# 16, 6 for bs of 4
|
# 16, 6 for bs of 4
|
||||||
@@ -1614,8 +1600,6 @@ class StableDiffusion:
|
|||||||
width=width, # 1024
|
width=width, # 1024
|
||||||
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||||
)
|
)
|
||||||
|
|
||||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
|
||||||
elif self.is_v3:
|
elif self.is_v3:
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||||
@@ -1624,7 +1608,6 @@ class StableDiffusion:
|
|||||||
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
|
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).sample
|
).sample
|
||||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
|
||||||
elif self.is_auraflow:
|
elif self.is_auraflow:
|
||||||
# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
|
# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
|||||||
@@ -1008,3 +1008,19 @@ def apply_snr_weight(
|
|||||||
snr_adjusted_loss = loss * snr_weight
|
snr_adjusted_loss = loss * snr_weight
|
||||||
|
|
||||||
return snr_adjusted_loss
|
return snr_adjusted_loss
|
||||||
|
|
||||||
|
|
||||||
|
def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
|
||||||
|
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
|
||||||
|
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||||
|
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||||
|
out_chunks = []
|
||||||
|
# unsqueeze if timestep is zero dim
|
||||||
|
for idx in range(model_output.shape[0]):
|
||||||
|
sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
|
||||||
|
dtype=model_output.dtype, device=model_output.device)
|
||||||
|
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||||
|
# Preconditioning of the model outputs.
|
||||||
|
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
|
||||||
|
out_chunks.append(out)
|
||||||
|
return torch.cat(out_chunks, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user