Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Allow masked losses with video models
All four hunks below extend places in SDTrainer where mask and loss handling assumed 4D image tensors (B,C,H,W), adding branches for 5D video tensors (B,C,T,H,W).
@@ -523,16 +523,33 @@ class SDTrainer(BaseSDTrainProcess):
             assert not self.train_config.train_turbo
             with torch.no_grad():
                 prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
+                if len(noise_pred.shape) == 5:
+                    # video B,C,T,H,W
+                    lat_height = batch.latents.shape[3]
+                    lat_width = batch.latents.shape[4]
+                else:
+                    lat_height = batch.latents.shape[2]
+                    lat_width = batch.latents.shape[3]
                 # resize to size of noise_pred
-                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
+                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic')
                 # stack first channel to match channels of noise_pred
                 prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)

+                if len(noise_pred.shape) == 5:
+                    prior_mask = prior_mask.unsqueeze(2)  # add time dimension back for video
+                    prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1)
+
                 prior_mask_multiplier = 1.0 - prior_mask

                 # scale so it is a mean of 1
                 prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
-        if self.sd.is_flow_matching:
+        if hasattr(self.sd, 'get_loss_target'):
+            target = self.sd.get_loss_target(
+                noise=noise,
+                batch=batch,
+                timesteps=timesteps,
+            ).detach()
+        elif self.sd.is_flow_matching:
             target = (noise - batch.latents).detach()
         else:
             target = noise
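Taken together, the mask preparation in this hunk does four things: resize the pixel-space mask to the latent grid, replicate its single channel to match the prediction's channels, add (and repeat) a time axis when the prediction is 5D video, then invert and renormalize so the multiplier has mean 1 and leaves the overall loss scale unchanged. A standalone sketch of that flow, with shapes and names that are illustrative rather than taken from the trainer:

import torch
import torch.nn.functional as F

# Illustrative shapes; none of these names come from the trainer itself.
B, C, T, H, W = 2, 16, 9, 32, 32
noise_pred = torch.randn(B, C, T, H, W)   # video prediction B,C,T,H,W
mask = torch.rand(B, 1, 256, 256)         # pixel-space mask B,1,h,w

# resize the mask to the latent spatial size
mask = F.interpolate(mask, size=(H, W), mode='bicubic')
# replicate the single channel to match the prediction's channels
mask = mask.repeat(1, noise_pred.shape[1], 1, 1)  # B,C,H,W
if noise_pred.ndim == 5:
    # add a time axis and repeat it across frames
    mask = mask.unsqueeze(2).repeat(1, 1, noise_pred.shape[2], 1, 1)

prior_mask_multiplier = 1.0 - mask
# renormalize so masking does not change the mean loss magnitude
prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
assert prior_mask_multiplier.shape == noise_pred.shape

The sketch replicates along the channel dimension with repeat, where the diff stacks slices with torch.cat; broadcasting a (B,1,...) multiplier would also work and avoids materializing the repeated tensor.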
@@ -770,9 +787,15 @@ class SDTrainer(BaseSDTrainProcess):

         # multiply by our mask
         try:
+            if len(noise_pred.shape) == 5:
+                # video B,C,T,H,W
+                mask_multiplier = mask_multiplier.unsqueeze(2)  # add time dimension back for video
+                mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1)
             loss = loss * mask_multiplier
-        except:
+        except Exception as e:
+            # todo handle mask with video models
             print("Could not apply mask multiplier to loss")
+            print(e)
             pass

         prior_loss = None
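Once the multiplier matches the loss tensor's rank, the masking step itself is just an elementwise multiply. A minimal sketch, assuming the elementwise loss already has the prediction's shape (names and shapes are illustrative):

import torch

def apply_mask(loss: torch.Tensor, mask_multiplier: torch.Tensor) -> torch.Tensor:
    if loss.ndim == 5 and mask_multiplier.ndim == 4:
        # video B,C,T,H,W: insert a time axis and repeat it across frames
        mask_multiplier = mask_multiplier.unsqueeze(2).repeat(1, 1, loss.shape[2], 1, 1)
    return loss * mask_multiplier

loss = torch.randn(2, 16, 9, 32, 32) ** 2   # B,C,T,H,W elementwise loss
mask = torch.rand(2, 16, 32, 32)            # B,C,H,W multiplier
print(apply_mask(loss, mask).shape)         # torch.Size([2, 16, 9, 32, 32])

Since unsqueeze(2) alone yields a broadcastable (B,C,1,H,W) tensor, the explicit repeat could be dropped in favor of broadcasting; the diff keeps the shapes fully materialized.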
@@ -788,11 +811,18 @@ class SDTrainer(BaseSDTrainProcess):
                 print_acc("Prior loss is nan")
                 prior_loss = None
             else:
-                prior_loss = prior_loss.mean([1, 2, 3])
+                if len(noise_pred.shape) == 5:
+                    # video B,C,T,H,W
+                    prior_loss = prior_loss.mean([1, 2, 3, 4])
+                else:
+                    prior_loss = prior_loss.mean([1, 2, 3])
                 # loss = loss + prior_loss
                 # loss = loss + prior_loss
                 # loss = loss + prior_loss
-        loss = loss.mean([1, 2, 3])
+        if len(noise_pred.shape) == 5:
+            loss = loss.mean([1, 2, 3, 4])
+        else:
+            loss = loss.mean([1, 2, 3])
         # apply loss multiplier before prior loss
         # multiply by our mask
         try:
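Both reductions here collapse every non-batch dimension to get a per-sample scalar, so the rank-specific if/else could equally be written once for any rank. A hedged alternative sketch, not what the commit does, just an equivalent formulation:

import torch

# Rank-agnostic per-sample reduction: mean over every non-batch dimension,
# covering both image (B,C,H,W) and video (B,C,T,H,W) losses.
def per_sample_mean(t: torch.Tensor) -> torch.Tensor:
    return t.mean(dim=list(range(1, t.ndim)))

image_loss = torch.randn(2, 4, 32, 32) ** 2
video_loss = torch.randn(2, 4, 9, 32, 32) ** 2
print(per_sample_mean(image_loss).shape)  # torch.Size([2])
print(per_sample_mean(video_loss).shape)  # torch.Size([2])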
@@ -1268,8 +1298,15 @@ class SDTrainer(BaseSDTrainProcess):
             # upsampling no supported for bfloat16
             mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
             # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
+            if len(noisy_latents.shape) == 5:
+                # video B,C,T,H,W
+                h = noisy_latents.shape[3]
+                w = noisy_latents.shape[4]
+            else:
+                h = noisy_latents.shape[2]
+                w = noisy_latents.shape[3]
             mask_multiplier = torch.nn.functional.interpolate(
-                mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+                mask_multiplier, size=(h, w)
             )
             # expand to match latents
             mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
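torch.nn.functional.interpolate expects a size whose length matches the input's number of spatial dimensions, and the mask here stays 4D (B,1,H,W) even for video, so only the source of (h, w) changes with the latents' rank. A standalone sketch under those assumptions (shapes illustrative):

import torch
import torch.nn.functional as F

# Only the (h, w) selection depends on the latents' rank; the mask stays 4D.
def latent_hw(noisy_latents: torch.Tensor) -> tuple:
    if noisy_latents.ndim == 5:
        # video B,C,T,H,W
        return noisy_latents.shape[3], noisy_latents.shape[4]
    return noisy_latents.shape[2], noisy_latents.shape[3]

noisy_latents = torch.randn(2, 16, 9, 32, 32)  # video latents
mask_multiplier = torch.rand(2, 1, 256, 256)   # pixel-space mask B,1,h,w
h, w = latent_hw(noisy_latents)
mask_multiplier = F.interpolate(mask_multiplier, size=(h, w))
# expand the single channel to match the latent channels
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
print(mask_multiplier.shape)                   # torch.Size([2, 16, 32, 32])

The trainer first casts the mask to float16 because upsampling is not supported for bfloat16; the sketch keeps float32 to stay runnable on CPU.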