Allow masked losses with video models

Jaret Burkett
2025-09-30 14:57:07 -06:00
parent 67ed563e03
commit 2ba4000704


@@ -523,16 +523,33 @@ class SDTrainer(BaseSDTrainProcess):
 assert not self.train_config.train_turbo
 with torch.no_grad():
     prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
+    if len(noise_pred.shape) == 5:
+        # video B,C,T,H,W
+        lat_height = batch.latents.shape[3]
+        lat_width = batch.latents.shape[4]
+    else:
+        lat_height = batch.latents.shape[2]
+        lat_width = batch.latents.shape[3]
     # resize to size of noise_pred
-    prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
+    prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic')
     # stack first channel to match channels of noise_pred
     prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
+    if len(noise_pred.shape) == 5:
+        prior_mask = prior_mask.unsqueeze(2)  # add time dimension back for video
+        prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1)
     prior_mask_multiplier = 1.0 - prior_mask
     # scale so it is a mean of 1
     prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
-    if self.sd.is_flow_matching:
+    if hasattr(self.sd, 'get_loss_target'):
+        target = self.sd.get_loss_target(
+            noise=noise,
+            batch=batch,
+            timesteps=timesteps,
+        ).detach()
+    elif self.sd.is_flow_matching:
         target = (noise - batch.latents).detach()
     else:
         target = noise
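
Taken on its own, the mask preparation this hunk generalizes does the following: the (B, 1, H, W) mask is resized to the latent grid, broadcast across channels, tiled across the time axis for video, then inverted and rescaled to a mean of 1. A minimal standalone sketch of that logic; the function name is illustrative, and it tiles channels per sample with repeat rather than the commit's torch.cat over the first batch element:

import torch
import torch.nn.functional as F

def prior_mask_multiplier(mask, noise_pred, latents):
    # mask: (B, 1, H_img, W_img); noise_pred/latents: (B, C, H, W) or (B, C, T, H, W)
    if noise_pred.ndim == 5:
        lat_h, lat_w = latents.shape[3], latents.shape[4]   # video B,C,T,H,W
    else:
        lat_h, lat_w = latents.shape[2], latents.shape[3]   # image B,C,H,W
    mask = F.interpolate(mask, size=(lat_h, lat_w), mode='bicubic')
    mask = mask.repeat(1, noise_pred.shape[1], 1, 1)        # match channel count
    if noise_pred.ndim == 5:
        # add the time axis back and tile it so every frame is masked
        mask = mask.unsqueeze(2).repeat(1, 1, noise_pred.shape[2], 1, 1)
    multiplier = 1.0 - mask
    return multiplier / multiplier.mean()                   # rescale to mean 1

# example: a batch of 2 masks applied to 8-frame video latents
mult = prior_mask_multiplier(torch.rand(2, 1, 512, 512),
                             torch.randn(2, 16, 8, 64, 64),
                             torch.randn(2, 16, 8, 64, 64))
assert mult.shape == (2, 16, 8, 64, 64)

Rescaling the multiplier to a mean of 1 keeps the expected loss magnitude independent of how much of the frame is masked, so the learning rate does not need retuning per dataset.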
@@ -770,9 +787,15 @@ class SDTrainer(BaseSDTrainProcess):
 # multiply by our mask
 try:
+    if len(noise_pred.shape) == 5:
+        # video B,C,T,H,W
+        mask_multiplier = mask_multiplier.unsqueeze(2)  # add time dimension back for video
+        mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1)
     loss = loss * mask_multiplier
-except:
+except Exception as e:
     # todo handle mask with video models
+    print("Could not apply mask multiplier to loss")
+    print(e)
     pass
 prior_loss = None
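
The change reads more easily as a small helper: the mask multiplier is spatial, (B, C, H, W), so a five-dimensional video loss needs a time axis tiled to the frame count before the elementwise product, and the broadened except now reports the failure instead of silently skipping the mask. A sketch under those assumptions (the helper name is hypothetical):

def apply_mask_multiplier(loss, mask_multiplier):
    # loss: (B, C, H, W) or (B, C, T, H, W); mask_multiplier: (B, C, H, W)
    try:
        if loss.ndim == 5:
            # video B,C,T,H,W: tile the mask across the time axis
            mask_multiplier = mask_multiplier.unsqueeze(2).repeat(1, 1, loss.shape[2], 1, 1)
        return loss * mask_multiplier
    except Exception as e:
        # degrade to the unmasked loss rather than aborting the training step
        print("Could not apply mask multiplier to loss")
        print(e)
        return loss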
@@ -788,11 +811,18 @@ class SDTrainer(BaseSDTrainProcess):
     print_acc("Prior loss is nan")
     prior_loss = None
 else:
-    prior_loss = prior_loss.mean([1, 2, 3])
+    if len(noise_pred.shape) == 5:
+        # video B,C,T,H,W
+        prior_loss = prior_loss.mean([1, 2, 3, 4])
+    else:
+        prior_loss = prior_loss.mean([1, 2, 3])
 # loss = loss + prior_loss
 # loss = loss + prior_loss
 # loss = loss + prior_loss
-loss = loss.mean([1, 2, 3])
+if len(noise_pred.shape) == 5:
+    loss = loss.mean([1, 2, 3, 4])
+else:
+    loss = loss.mean([1, 2, 3])
 # apply loss multiplier before prior loss
 # multiply by our mask
 try:
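
Both reductions collapse every non-batch dimension so each sample keeps a single scalar, which the per-sample loss multiplier applied just below can still weight. A rank-independent equivalent (a generalization for illustration, not what the commit uses):

def reduce_per_sample(loss):
    # (B, C, H, W) -> (B,) or (B, C, T, H, W) -> (B,)
    return loss.mean(dim=list(range(1, loss.ndim)))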
@@ -1268,8 +1298,15 @@ class SDTrainer(BaseSDTrainProcess):
 # upsampling not supported for bfloat16
 mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
 # scale down to the size of the latents; mask_multiplier shape (bs, 1, height, width), noisy_latents shape (bs, channels, height, width)
+if len(noisy_latents.shape) == 5:
+    # video B,C,T,H,W
+    h = noisy_latents.shape[3]
+    w = noisy_latents.shape[4]
+else:
+    h = noisy_latents.shape[2]
+    w = noisy_latents.shape[3]
 mask_multiplier = torch.nn.functional.interpolate(
-    mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+    mask_multiplier, size=(h, w)
 )
 # expand to match latents
 mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
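
The non-prior path follows the same pattern: pick the latent height and width from the correct axes, resize the mask in float16 (the source comment notes upsampling is not supported for bfloat16), and broadcast the single mask channel across latent channels; for video, the time axis is added later, as in the first hunk. A condensed sketch with an illustrative function name:

def mask_to_latent_size(mask_tensor, noisy_latents, device):
    m = mask_tensor.to(device, dtype=torch.float16).detach()   # upsampling lacks bfloat16 support
    if noisy_latents.ndim == 5:
        h, w = noisy_latents.shape[3], noisy_latents.shape[4]  # video B,C,T,H,W
    else:
        h, w = noisy_latents.shape[2], noisy_latents.shape[3]  # image B,C,H,W
    m = torch.nn.functional.interpolate(m, size=(h, w))
    # expand returns a broadcast view: the one mask channel covers all latent channels without copying
    return m.expand(-1, noisy_latents.shape[1], -1, -1)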