Added initial support for fine-tuning Wan i2v (WIP)

This commit is contained in:
Jaret Burkett
2025-04-07 20:34:38 -06:00
parent 38ad5a4644
commit a8680c75eb
10 changed files with 575 additions and 286 deletions

View File

@@ -618,277 +618,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss
def get_guided_loss_targeted_polarity(
    self,
    noisy_latents: torch.Tensor,
    conditional_embeds: PromptEmbeds,
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    **kwargs
):
    """Run one "targeted polarity" guidance step and back-propagate its loss.

    The midpoint between ``batch.latents`` and ``batch.unconditional_latents``
    is noised and predicted three times: once with ``self.network``
    deactivated (the baseline) and, in a single doubled batch, with the
    network active at +2x and -2x of ``network_weight_list``.  After the
    baseline is subtracted, the +2x half is regressed toward
    ``unconditional - midpoint`` and the -2x half toward
    ``conditional - midpoint``.

    ``noisy_latents``, ``match_adapter_assist`` and ``**kwargs`` are accepted
    but not read by this method.

    Side effects: ``self.sd.unet`` is left in train mode, ``self.network`` is
    left active with ``multiplier`` set to the doubled weight list, and
    ``self.accelerator.backward`` is called on the loss.

    Returns:
        The scalar loss, detached from the graph and then flagged with
        ``requires_grad`` so a later backward call by the caller does not
        raise.
    """

    def _per_sample_mse(pred, target):
        # squared error computed in fp32, then averaged over dims 1-3
        squared_error = torch.nn.functional.mse_loss(
            pred.float(),
            target.float(),
            reduction="none"
        )
        return squared_error.mean([1, 2, 3])

    # NOTE(review): the no_grad scope has to end before the trainable forward
    # pass, otherwise backward() below would have no graph; confirm the exact
    # boundary against the upstream file.
    with torch.no_grad():
        dtype = get_torch_dtype(self.train_config.dtype)
        device = self.device_torch

        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # each side's regression target is its offset from the shared midpoint
        midpoint = (cond_latents + uncond_latents) / 2.0
        uncond_offset = uncond_latents - midpoint
        cond_offset = cond_latents - midpoint

        # both noisy inputs are built from the same midpoint latents
        cond_noisy = self.sd.add_noise(midpoint, noise, timesteps).detach()
        uncond_noisy = self.sd.add_noise(midpoint, noise, timesteps).detach()

        # baseline pass: network off, unet in eval mode
        self.network.is_active = False
        self.sd.unet.eval()
        baseline = self.sd.predict_noise(
            latents=uncond_noisy.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()

        # stack two copies of every input so both polarities share one forward pass
        doubled_embeds = concat_prompt_embeds([conditional_embeds] * 2)
        doubled_latents = torch.cat([cond_noisy] * 2, dim=0)
        doubled_timesteps = torch.cat([timesteps] * 2, dim=0)

        # the two polarities meet in the middle, so each side trains at twice
        # the configured strength; the positive half comes first
        polarity_weights = (
            [w * 2.0 for w in network_weight_list]
            + [w * -2.0 for w in network_weight_list]
        )

    # trainable pass: network back on with one multiplier per doubled-batch entry
    self.sd.unet.train()
    self.network.is_active = True
    self.network.multiplier = polarity_weights

    prediction = self.sd.predict_noise(
        latents=doubled_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=doubled_embeds.to(device, dtype=dtype).detach(),
        timestep=doubled_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    # what the network added on top of the baseline, per polarity
    positive_half, negative_half = prediction.chunk(2, dim=0)
    positive_delta = positive_half - baseline
    negative_delta = negative_half - baseline

    positive_loss = _per_sample_mse(positive_delta, uncond_offset)
    negative_loss = _per_sample_mse(negative_delta, cond_offset)

    loss = ((positive_loss + negative_loss) / 2.0).mean()
    self.accelerator.backward(loss)

    # hand back a graph-free tensor that still reports requires_grad so the
    # parent class can call backward on it without throwing
    return loss.detach().requires_grad_(True)
def get_guided_loss_masked_polarity(
    self,
    noisy_latents: torch.Tensor,
    conditional_embeds: PromptEmbeds,
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    **kwargs
):
    """Run one "masked polarity" guidance step and back-propagate its loss.

    ``batch.latents`` (conditional) and ``batch.unconditional_latents`` are
    noised with the same ``noise``/``timesteps`` and predicted in one doubled
    batch: the conditional half with ``network_weight_list`` as multipliers,
    the unconditional half with the negated multipliers.  Both halves are
    regressed toward ``noise``.  A mask built from the normalised absolute
    difference of the two latents re-weights the element-wise errors:
    ``1 + mask`` for the positive half and ``1 - mask`` for the negative half.

    ``noisy_latents``, ``match_adapter_assist`` and ``**kwargs`` are accepted
    but not read by this method.

    Side effects: ``self.sd.unet`` is left in train mode, ``self.network`` is
    left active with ``multiplier`` set to the doubled weight list, and
    ``self.accelerator.backward`` is called on the loss.

    Returns:
        The scalar loss, detached from the graph and then flagged with
        ``requires_grad`` so a later backward call by the caller does not
        raise.
    """
    # NOTE(review): the no_grad scope has to end before the trainable forward
    # pass, otherwise backward() below would have no graph; confirm the exact
    # boundary against the upstream file.
    with torch.no_grad():
        dtype = get_torch_dtype(self.train_config.dtype)

        conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()

        # differential mask: |conditional - unconditional|, scaled so the
        # largest value over dims 1-3 of each sample becomes 1
        differential_mask = torch.abs(conditional_latents - unconditional_latents)
        max_differential = differential_mask.amax(dim=(1, 2, 3), keepdim=True)
        # identical latents give a max of 0; clamping keeps the reciprocal
        # finite so the mask is 0 for that sample instead of NaN (0 * inf)
        max_differential = max_differential.clamp(
            min=torch.finfo(max_differential.dtype).tiny
        )
        differential_scaler = 1.0 / max_differential
        differential_mask = differential_mask * differential_scaler

        # stretch the mask 10x around the spread point, then clip to [0, 1]
        spread_point = 0.1
        differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)

        conditional_noisy_latents = self.sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = self.sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once:
        # first half conditional, second half unconditional
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # positive multipliers for the conditional half, negated multipliers
        # for the unconditional half
        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        cat_network_weight_list = positive_network_weights + negative_network_weights

    # make sure the unet is training and the LoRA network is active
    self.sd.unet.train()
    self.network.is_active = True
    self.network.multiplier = cat_network_weight_list

    # do our prediction with LoRA active on the doubled batch
    prediction = self.sd.predict_noise(
        latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
    noise_target = noise.float()

    # positive half: errors weigh more where the two latents differ
    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        noise_target,
        reduction="none"
    )
    pred_loss = (pred_loss * (1.0 + differential_mask)).mean([1, 2, 3])

    # negative half: errors weigh less where the two latents differ
    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        noise_target,
        reduction="none"
    )
    pred_neg_loss = (pred_neg_loss * (1.0 - differential_mask)).mean([1, 2, 3])

    loss = (pred_loss + pred_neg_loss).mean()
    self.accelerator.backward(loss)

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
def get_prior_prediction(
self,
@@ -985,6 +714,7 @@ class SDTrainer(BaseSDTrainProcess):
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
rescale_cfg=self.train_config.cfg_rescale,
batch=batch,
**pred_kwargs # adapter residuals in here
)
if was_unet_training:
@@ -1021,6 +751,7 @@ class SDTrainer(BaseSDTrainProcess):
timesteps: Union[int, torch.Tensor] = 1,
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
batch: Optional['DataLoaderBatchDTO'] = None,
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1034,6 +765,7 @@ class SDTrainer(BaseSDTrainProcess):
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
batch=batch,
**kwargs
)
@@ -1690,6 +1422,7 @@ class SDTrainer(BaseSDTrainProcess):
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
self.after_unet_predict()
@@ -1723,6 +1456,7 @@ class SDTrainer(BaseSDTrainProcess):
timesteps=timesteps,
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier