mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 15:07:22 +00:00
Adjusted flow matching so target noise multiplier works properly with it.
This commit is contained in:
@@ -166,9 +166,9 @@ class StableDiffusion:
|
||||
|
||||
self.config_file = None
|
||||
|
||||
self.is_rectified_flow = False
|
||||
self.is_flow_matching = False
|
||||
if self.is_flux or self.is_v3 or self.is_auraflow:
|
||||
self.is_rectified_flow = True
|
||||
self.is_flow_matching = True
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
@@ -1337,20 +1337,6 @@ class StableDiffusion:
|
||||
)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
|
||||
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
# unsqueeze if timestep is zero dim
|
||||
for idx in range(model_output.shape[0]):
|
||||
sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
# Preconditioning of the model outputs.
|
||||
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
|
||||
out_chunks.append(out)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
if self.is_xl:
|
||||
with torch.no_grad():
|
||||
# 16, 6 for bs of 4
|
||||
@@ -1614,8 +1600,6 @@ class StableDiffusion:
|
||||
width=width, # 1024
|
||||
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||
)
|
||||
|
||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
||||
elif self.is_v3:
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
@@ -1624,7 +1608,6 @@ class StableDiffusion:
|
||||
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
|
||||
**kwargs,
|
||||
).sample
|
||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
||||
elif self.is_auraflow:
|
||||
# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
|
||||
Reference in New Issue
Block a user