diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 4d8489ba..7332a123 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -1150,7 +1150,7 @@ class DiffusionFeatureExtractor9(nn.Module): perceptual_loss = torch.nn.functional.mse_loss( pred.float(), target.float(), reduction="none" ) - velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.001) ** 2) + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) loss_perceptual = (perceptual_loss * velocity_equiv_weight).mean() if self.do_partial_step: