Added support for polarity guidance for flow matching models

This commit is contained in:
Jaret Burkett
2024-08-15 12:22:00 -06:00
parent b99d36dfdb
commit 0355662e8e

View File

@@ -419,6 +419,15 @@ def get_guided_loss_polarity(
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
target_pos = noise
target_neg = noise
if sd.is_flow_matching:
# set the timesteps for flow matching as linear since we will do weighing
sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
target_pos = (noise - conditional_latents).detach()
target_neg = (noise - unconditional_latents).detach()
conditional_noisy_latents = sd.add_noise(
conditional_latents,
noise,
@@ -459,19 +468,24 @@ def get_guided_loss_polarity(
pred_loss = torch.nn.functional.mse_loss(
pred_pos.float(),
noise.float(),
target_pos.float(),
reduction="none"
)
# pred_loss = pred_loss.mean([1, 2, 3])
pred_neg_loss = torch.nn.functional.mse_loss(
pred_neg.float(),
noise.float(),
target_neg.float(),
reduction="none"
)
loss = pred_loss + pred_neg_loss
if sd.is_flow_matching:
timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
loss = loss * timestep_weight
loss = loss.mean([1, 2, 3])
loss = loss.mean()
if scaler is not None: