Adjust DFE to handle 5-dimensional latent spaces

This commit is contained in:
Jaret Burkett
2025-10-27 07:48:44 -06:00
parent 8c12977891
commit 42e5e3cd1c

View File

@@ -617,33 +617,28 @@ class SDTrainer(BaseSDTrainProcess):
stepped_latents = torch.cat(stepped_chunks, dim=0)
stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
# resize to half the size of the latents
stepped_latents_half = torch.nn.functional.interpolate(
stepped_latents,
size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
mode='bilinear',
align_corners=False
)
pred_features = self.dfe(stepped_latents.float())
pred_features_half = self.dfe(stepped_latents_half.float())
sl = stepped_latents
if len(sl.shape) == 5:
# video B,C,T,H,W
sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
b, t, c, h, w = sl.shape
sl = sl.reshape(b * t, c, h, w)
pred_features = self.dfe(sl.float())
with torch.no_grad():
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
batch_latents_half = torch.nn.functional.interpolate(
batch.latents.to(self.device_torch, dtype=torch.float32),
size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
mode='bilinear',
align_corners=False
)
target_features_half = self.dfe(batch_latents_half)
bl = batch.latents
bl = bl.to(self.sd.vae.device)
if len(bl.shape) == 5:
# video B,C,T,H,W
bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
b, t, c, h, w = bl.shape
bl = bl.reshape(b * t, c, h, w)
target_features = self.dfe(bl.float())
# scale dfe so it is weaker at higher noise levels
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
additional_loss += dfe_loss.mean()
elif self.dfe.version == 2:
# version 2
# do diffusion feature extraction on target