mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Added support for polarity guidance for flow matching models
This commit is contained in:
@@ -419,6 +419,15 @@ def get_guided_loss_polarity(
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
target_pos = noise
|
||||
target_neg = noise
|
||||
|
||||
if sd.is_flow_matching:
|
||||
# set the timesteps for flow matching as linear since we will do weighing
|
||||
sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
|
||||
target_pos = (noise - conditional_latents).detach()
|
||||
target_neg = (noise - unconditional_latents).detach()
|
||||
|
||||
conditional_noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
noise,
|
||||
@@ -459,19 +468,24 @@ def get_guided_loss_polarity(
|
||||
|
||||
pred_loss = torch.nn.functional.mse_loss(
|
||||
pred_pos.float(),
|
||||
noise.float(),
|
||||
target_pos.float(),
|
||||
reduction="none"
|
||||
)
|
||||
# pred_loss = pred_loss.mean([1, 2, 3])
|
||||
|
||||
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||
pred_neg.float(),
|
||||
noise.float(),
|
||||
target_neg.float(),
|
||||
reduction="none"
|
||||
)
|
||||
|
||||
loss = pred_loss + pred_neg_loss
|
||||
|
||||
if sd.is_flow_matching:
|
||||
timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
|
||||
loss = loss * timestep_weight
|
||||
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = loss.mean()
|
||||
if scaler is not None:
|
||||
|
||||
Reference in New Issue
Block a user