diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 68aa8ca0..59d75988 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -617,33 +617,28 @@ class SDTrainer(BaseSDTrainProcess): stepped_latents = torch.cat(stepped_chunks, dim=0) stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype) - # resize to half the size of the latents - stepped_latents_half = torch.nn.functional.interpolate( - stepped_latents, - size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2), - mode='bilinear', - align_corners=False - ) - pred_features = self.dfe(stepped_latents.float()) - pred_features_half = self.dfe(stepped_latents_half.float()) + sl = stepped_latents + if len(sl.shape) == 5: + # video B,C,T,H,W + sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = sl.shape + sl = sl.reshape(b * t, c, h, w) + pred_features = self.dfe(sl.float()) with torch.no_grad(): - target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32)) - batch_latents_half = torch.nn.functional.interpolate( - batch.latents.to(self.device_torch, dtype=torch.float32), - size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2), - mode='bilinear', - align_corners=False - ) - target_features_half = self.dfe(batch_latents_half) + bl = batch.latents + bl = bl.to(self.sd.vae.device) + if len(bl.shape) == 5: + # video B,C,T,H,W + bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = bl.shape + bl = bl.reshape(b * t, c, h, w) + target_features = self.dfe(bl.float()) # scale dfe so it is weaker at higher noise levels dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch) dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \ self.train_config.diffusion_feature_extractor_weight * dfe_scaler - - dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \ - self.train_config.diffusion_feature_extractor_weight * dfe_scaler - additional_loss += dfe_loss.mean() + dfe_loss_half.mean() + additional_loss += dfe_loss.mean() elif self.dfe.version == 2: # version 2 # do diffusion feature extraction on target