Adjusted flow matching so target noise multiplier works properly with it.

This commit is contained in:
Jaret Burkett
2024-08-05 11:40:05 -06:00
parent 0ea27011d5
commit edb7e827ee
4 changed files with 35 additions and 26 deletions

View File

@@ -1008,3 +1008,19 @@ def apply_snr_weight(
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss
def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
# unsqueeze if timestep is zero dim
for idx in range(model_output.shape[0]):
sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
dtype=model_output.dtype, device=model_output.device)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
out_chunks.append(out)
return torch.cat(out_chunks, dim=0)