mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Allow masked losses with video models
This commit is contained in:
@@ -523,16 +523,33 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
assert not self.train_config.train_turbo
|
assert not self.train_config.train_turbo
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
|
prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
|
||||||
|
if len(noise_pred.shape) == 5:
|
||||||
|
# video B,C,T,H,W
|
||||||
|
lat_height = batch.latents.shape[3]
|
||||||
|
lat_width = batch.latents.shape[4]
|
||||||
|
else:
|
||||||
|
lat_height = batch.latents.shape[2]
|
||||||
|
lat_width = batch.latents.shape[3]
|
||||||
# resize to size of noise_pred
|
# resize to size of noise_pred
|
||||||
prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
|
prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic')
|
||||||
# stack first channel to match channels of noise_pred
|
# stack first channel to match channels of noise_pred
|
||||||
prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
|
prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
|
||||||
|
|
||||||
|
if len(noise_pred.shape) == 5:
|
||||||
|
prior_mask = prior_mask.unsqueeze(2) # add time dimension back for video
|
||||||
|
prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1)
|
||||||
|
|
||||||
prior_mask_multiplier = 1.0 - prior_mask
|
prior_mask_multiplier = 1.0 - prior_mask
|
||||||
|
|
||||||
# scale so it is a mean of 1
|
# scale so it is a mean of 1
|
||||||
prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
|
prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
|
||||||
if self.sd.is_flow_matching:
|
if hasattr(self.sd, 'get_loss_target'):
|
||||||
|
target = self.sd.get_loss_target(
|
||||||
|
noise=noise,
|
||||||
|
batch=batch,
|
||||||
|
timesteps=timesteps,
|
||||||
|
).detach()
|
||||||
|
elif self.sd.is_flow_matching:
|
||||||
target = (noise - batch.latents).detach()
|
target = (noise - batch.latents).detach()
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
@@ -770,9 +787,15 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# multiply by our mask
|
# multiply by our mask
|
||||||
try:
|
try:
|
||||||
|
if len(noise_pred.shape) == 5:
|
||||||
|
# video B,C,T,H,W
|
||||||
|
mask_multiplier = mask_multiplier.unsqueeze(2) # add time dimension back for video
|
||||||
|
mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1)
|
||||||
loss = loss * mask_multiplier
|
loss = loss * mask_multiplier
|
||||||
except:
|
except Exception as e:
|
||||||
# todo handle mask with video models
|
# todo handle mask with video models
|
||||||
|
print("Could not apply mask multiplier to loss")
|
||||||
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
prior_loss = None
|
prior_loss = None
|
||||||
@@ -788,11 +811,18 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
print_acc("Prior loss is nan")
|
print_acc("Prior loss is nan")
|
||||||
prior_loss = None
|
prior_loss = None
|
||||||
else:
|
else:
|
||||||
prior_loss = prior_loss.mean([1, 2, 3])
|
if len(noise_pred.shape) == 5:
|
||||||
|
# video B,C,T,H,W
|
||||||
|
prior_loss = prior_loss.mean([1, 2, 3, 4])
|
||||||
|
else:
|
||||||
|
prior_loss = prior_loss.mean([1, 2, 3])
|
||||||
# loss = loss + prior_loss
|
# loss = loss + prior_loss
|
||||||
# loss = loss + prior_loss
|
# loss = loss + prior_loss
|
||||||
# loss = loss + prior_loss
|
# loss = loss + prior_loss
|
||||||
loss = loss.mean([1, 2, 3])
|
if len(noise_pred.shape) == 5:
|
||||||
|
loss = loss.mean([1, 2, 3, 4])
|
||||||
|
else:
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
# apply loss multiplier before prior loss
|
# apply loss multiplier before prior loss
|
||||||
# multiply by our mask
|
# multiply by our mask
|
||||||
try:
|
try:
|
||||||
@@ -1268,8 +1298,15 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# upsampling no supported for bfloat16
|
# upsampling no supported for bfloat16
|
||||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||||
|
if len(noisy_latents.shape) == 5:
|
||||||
|
# video B,C,T,H,W
|
||||||
|
h = noisy_latents.shape[3]
|
||||||
|
w = noisy_latents.shape[4]
|
||||||
|
else:
|
||||||
|
h = noisy_latents.shape[2]
|
||||||
|
w = noisy_latents.shape[3]
|
||||||
mask_multiplier = torch.nn.functional.interpolate(
|
mask_multiplier = torch.nn.functional.interpolate(
|
||||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
mask_multiplier, size=(h, w)
|
||||||
)
|
)
|
||||||
# expand to match latents
|
# expand to match latents
|
||||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||||
|
|||||||
Reference in New Issue
Block a user