Adjusted flow matching so target noise multiplier works properly with it.

This commit is contained in:
Jaret Burkett
2024-08-05 11:40:05 -06:00
parent 0ea27011d5
commit edb7e827ee
4 changed files with 35 additions and 26 deletions

View File

@@ -31,7 +31,7 @@ from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
def flush():
@@ -328,14 +328,24 @@ class SDTrainer(BaseSDTrainProcess):
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif self.sd.is_rectified_flow:
elif self.sd.is_flow_matching:
# only if preconditioning model outputs
# if not preconditioning, (target = noise - batch.latents)
# target = noise - batch.latents
# if preconditioning outputs, target latents
# model_pred = model_pred * (-sigmas) + noisy_model_input
if self.train_config.target_noise_multiplier != 1.0:
# we are adjusting the target noise, need to recompute the noisy latents with
# the noise adjusted above
with torch.no_grad():
noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
noise_pred = precondition_model_outputs_flow_match(
noise_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
target = batch.latents.detach()
else:
target = noise
@@ -383,7 +393,7 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss_per_element
else:
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
if self.sd.is_rectified_flow and prior_pred is None:
if self.sd.is_flow_matching and prior_pred is None:
# outputs should be preprocessed latents
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
weighting = torch.ones_like(sigmas)

View File

@@ -959,7 +959,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
raise ValueError(f"Unknown content_or_style {content_or_style}")
# do flow matching
# if self.sd.is_rectified_flow:
# if self.sd.is_flow_matching:
# u = compute_density_for_timestep_sampling(
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
# batch_size=batch_size,

View File

@@ -166,9 +166,9 @@ class StableDiffusion:
self.config_file = None
self.is_rectified_flow = False
self.is_flow_matching = False
if self.is_flux or self.is_v3 or self.is_auraflow:
self.is_rectified_flow = True
self.is_flow_matching = True
def load_model(self):
if self.is_loaded:
@@ -1337,20 +1337,6 @@ class StableDiffusion:
)
return torch.cat(out_chunks, dim=0)
def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
# unsqueeze if timestep is zero dim
for idx in range(model_output.shape[0]):
sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
out_chunks.append(out)
return torch.cat(out_chunks, dim=0)
if self.is_xl:
with torch.no_grad():
# 16, 6 for bs of 4
@@ -1614,8 +1600,6 @@ class StableDiffusion:
width=width, # 1024
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
)
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
elif self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -1624,7 +1608,6 @@ class StableDiffusion:
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
elif self.is_auraflow:
# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML

View File

@@ -1008,3 +1008,19 @@ def apply_snr_weight(
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss
def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
# unsqueeze if timestep is zero dim
for idx in range(model_output.shape[0]):
sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
dtype=model_output.dtype, device=model_output.device)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
out_chunks.append(out)
return torch.cat(out_chunks, dim=0)