mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Adjusted flow matching so target noise multiplier works properly with it.
This commit is contained in:
@@ -1008,3 +1008,19 @@ def apply_snr_weight(
|
||||
snr_adjusted_loss = loss * snr_weight
|
||||
|
||||
return snr_adjusted_loss
|
||||
|
||||
|
||||
def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
|
||||
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
# unsqueeze if timestep is zero dim
|
||||
for idx in range(model_output.shape[0]):
|
||||
sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
|
||||
dtype=model_output.dtype, device=model_output.device)
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
# Preconditioning of the model outputs.
|
||||
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
|
||||
out_chunks.append(out)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user