Change the velocity weight cap on dfe 9

This commit is contained in:
Jaret Burkett
2026-05-07 07:37:05 -06:00
parent 6bb8acbffc
commit a12ddd72a1

View File

@@ -1150,7 +1150,7 @@ class DiffusionFeatureExtractor9(nn.Module):
perceptual_loss = torch.nn.functional.mse_loss(
pred.float(), target.float(), reduction="none"
)
velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.001) ** 2)
velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2)
loss_perceptual = (perceptual_loss * velocity_equiv_weight).mean()
if self.do_partial_step: