mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Adjust DFE to handle 5 dimension latent spaces
This commit is contained in:
@@ -617,33 +617,28 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
|
||||
# resize to half the size of the latents
|
||||
stepped_latents_half = torch.nn.functional.interpolate(
|
||||
stepped_latents,
|
||||
size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
pred_features = self.dfe(stepped_latents.float())
|
||||
pred_features_half = self.dfe(stepped_latents_half.float())
|
||||
sl = stepped_latents
|
||||
if len(sl.shape) == 5:
|
||||
# video B,C,T,H,W
|
||||
sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
|
||||
b, t, c, h, w = sl.shape
|
||||
sl = sl.reshape(b * t, c, h, w)
|
||||
pred_features = self.dfe(sl.float())
|
||||
with torch.no_grad():
|
||||
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
|
||||
batch_latents_half = torch.nn.functional.interpolate(
|
||||
batch.latents.to(self.device_torch, dtype=torch.float32),
|
||||
size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
target_features_half = self.dfe(batch_latents_half)
|
||||
bl = batch.latents
|
||||
bl = bl.to(self.sd.vae.device)
|
||||
if len(bl.shape) == 5:
|
||||
# video B,C,T,H,W
|
||||
bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
|
||||
b, t, c, h, w = bl.shape
|
||||
bl = bl.reshape(b * t, c, h, w)
|
||||
target_features = self.dfe(bl.float())
|
||||
# scale dfe so it is weaker at higher noise levels
|
||||
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
|
||||
|
||||
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
|
||||
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
|
||||
|
||||
dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
|
||||
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
|
||||
additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
|
||||
additional_loss += dfe_loss.mean()
|
||||
elif self.dfe.version == 2:
|
||||
# version 2
|
||||
# do diffusion feature extraction on target
|
||||
|
||||
Reference in New Issue
Block a user