diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 86c198ea..1787f0da 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -523,16 +523,33 @@ class SDTrainer(BaseSDTrainProcess):
             assert not self.train_config.train_turbo
             with torch.no_grad():
                 prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
+                if len(noise_pred.shape) == 5:
+                    # video B,C,T,H,W
+                    lat_height = batch.latents.shape[3]
+                    lat_width = batch.latents.shape[4]
+                else:
+                    lat_height = batch.latents.shape[2]
+                    lat_width = batch.latents.shape[3]
                 # resize to size of noise_pred
-                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
+                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic')
                 # stack first channel to match channels of noise_pred
                 prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
+
+                if len(noise_pred.shape) == 5:
+                    prior_mask = prior_mask.unsqueeze(2)  # add time dimension back for video
+                    prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1)

                 prior_mask_multiplier = 1.0 - prior_mask
                 # scale so it is a mean of 1
                 prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()

-        if self.sd.is_flow_matching:
+        if hasattr(self.sd, 'get_loss_target'):
+            target = self.sd.get_loss_target(
+                noise=noise,
+                batch=batch,
+                timesteps=timesteps,
+            ).detach()
+        elif self.sd.is_flow_matching:
             target = (noise - batch.latents).detach()
         else:
             target = noise
@@ -770,9 +787,15 @@ class SDTrainer(BaseSDTrainProcess):

         # multiply by our mask
         try:
+            if len(noise_pred.shape) == 5:
+                # video B,C,T,H,W
+                mask_multiplier = mask_multiplier.unsqueeze(2)  # add time dimension back for video
+                mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1)
             loss = loss * mask_multiplier
-        except:
+        except Exception as e:
             # todo handle mask with video models
+            print("Could not apply mask multiplier to loss")
+            print(e)
             pass

         prior_loss = None
@@ -788,11 +811,18 @@ class SDTrainer(BaseSDTrainProcess):
                     print_acc("Prior loss is nan")
                     prior_loss = None
                 else:
-                    prior_loss = prior_loss.mean([1, 2, 3])
+                    if len(noise_pred.shape) == 5:
+                        # video B,C,T,H,W
+                        prior_loss = prior_loss.mean([1, 2, 3, 4])
+                    else:
+                        prior_loss = prior_loss.mean([1, 2, 3])
                     # loss = loss + prior_loss
                     # loss = loss + prior_loss
                     # loss = loss + prior_loss
-        loss = loss.mean([1, 2, 3])
+        if len(noise_pred.shape) == 5:
+            loss = loss.mean([1, 2, 3, 4])
+        else:
+            loss = loss.mean([1, 2, 3])
         # apply loss multiplier before prior loss
         # multiply by our mask
         try:
@@ -1268,8 +1298,15 @@ class SDTrainer(BaseSDTrainProcess):
                 # upsampling no supported for bfloat16
                 mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
                 # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
+                if len(noisy_latents.shape) == 5:
+                    # video B,C,T,H,W
+                    h = noisy_latents.shape[3]
+                    w = noisy_latents.shape[4]
+                else:
+                    h = noisy_latents.shape[2]
+                    w = noisy_latents.shape[3]
                 mask_multiplier = torch.nn.functional.interpolate(
-                    mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+                    mask_multiplier, size=(h, w)
                 )
                 # expand to match latents
                 mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
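
For review context, a minimal standalone sketch (not part of the patch) of the shape handling these hunks add: the mask is resized at latent resolution, and for 5-D video latents the time dimension is re-inserted and the mask repeated across frames. The helper name expand_mask_to_latents and the concrete shapes below are illustrative assumptions, not identifiers from the codebase.

# Standalone sketch of the 5-D mask broadcast introduced above; shapes are
# illustrative, the trainer derives them from batch.latents / noise_pred.
import torch

def expand_mask_to_latents(mask: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
    """Resize a (B, 1, H, W) mask to latent resolution and broadcast it to
    image latents (B, C, H, W) or video latents (B, C, T, H, W)."""
    if latents.dim() == 5:  # video B,C,T,H,W
        h, w = latents.shape[3], latents.shape[4]
    else:  # image B,C,H,W
        h, w = latents.shape[2], latents.shape[3]
    mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='bicubic')
    mask = mask.repeat(1, latents.shape[1], 1, 1)  # match channel count
    if latents.dim() == 5:
        # add the time dimension back and repeat the mask for every frame
        mask = mask.unsqueeze(2).repeat(1, 1, latents.shape[2], 1, 1)
    return mask

mask = torch.rand(2, 1, 256, 256)
print(expand_mask_to_latents(mask, torch.rand(2, 4, 32, 32)).shape)     # (2, 4, 32, 32)
print(expand_mask_to_latents(mask, torch.rand(2, 4, 8, 32, 32)).shape)  # (2, 4, 8, 32, 32)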