mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added initial support for finetuning wan i2v WIP
This commit is contained in:
@@ -618,277 +618,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_targeted_polarity(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        **kwargs
):
    """Targeted-polarity guided loss (working title).

    The midpoint of ``batch.latents`` and ``batch.unconditional_latents`` is
    noised and predicted twice in one doubled batch: once with the network
    multipliers scaled by ``+2.0`` and once by ``-2.0``. After subtracting a
    baseline prediction made with the network disabled, the positive half is
    regressed (MSE) onto ``unconditional - midpoint`` and the negative half
    onto ``conditional - midpoint``.

    ``self.accelerator.backward`` is called on the loss inside this method.

    Args:
        noisy_latents: Not read by this method.
        conditional_embeds: Prompt embeddings used for every prediction.
        match_adapter_assist: Not read by this method.
        network_weight_list: Per-sample network multipliers.
        timesteps: Timesteps handed to ``add_noise`` / ``predict_noise``.
        pred_kwargs: Extra keyword arguments forwarded to ``predict_noise``.
        batch: Supplies ``latents`` and ``unconditional_latents``.
        noise: Noise tensor mixed into the midpoint latents.
        **kwargs: Accepted and ignored.

    Returns:
        The scalar loss, detached and then flagged ``requires_grad=True``.
    """
    # The preparation phase and the baseline pass need no autograd graph.
    with torch.no_grad():
        work_dtype = get_torch_dtype(self.train_config.dtype)
        device = self.device_torch

        cond_lat = batch.latents.to(device, dtype=work_dtype).detach()
        uncond_lat = batch.unconditional_latents.to(device, dtype=work_dtype).detach()

        # Split the pair around its midpoint; each offset is a regression target.
        midpoint = (cond_lat + uncond_lat) / 2.0
        uncond_offset = uncond_lat - midpoint
        cond_offset = cond_lat - midpoint

        # Both noisy inputs are built from the same midpoint latents.
        noisy_for_lora = self.sd.add_noise(midpoint, noise, timesteps).detach()
        noisy_for_baseline = self.sd.add_noise(midpoint, noise, timesteps).detach()

        # Switch the network off to see what the parent model predicts alone.
        self.network.is_active = False
        self.sd.unet.eval()

        # Control prediction; it is subtracted from both LoRA halves below.
        baseline = self.sd.predict_noise(
            latents=noisy_for_baseline.to(device, dtype=work_dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=work_dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()

        # Duplicate every input so both polarities share one forward pass.
        doubled_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        doubled_latents = torch.cat([noisy_for_lora, noisy_for_lora], dim=0)
        doubled_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # Polarity is split from the middle out, so the convergent point sits
        # at half network strength; hence the doubled multipliers.
        plus_weights = [w * 2.0 for w in network_weight_list]
        minus_weights = [w * -2.0 for w in network_weight_list]

        # Re-enable the network with the positive half first, negative second.
        self.sd.unet.train()
        self.network.is_active = True
        self.network.multiplier = plus_weights + minus_weights

    # This pass must run with autograd enabled: backward() is called on its loss.
    lora_pred = self.sd.predict_noise(
        latents=doubled_latents.to(self.device_torch, dtype=work_dtype).detach(),
        conditional_embeddings=doubled_embeds.to(self.device_torch, dtype=work_dtype).detach(),
        timestep=doubled_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    plus_pred, minus_pred = lora_pred.chunk(2, dim=0)
    plus_delta = plus_pred - baseline
    minus_delta = minus_pred - baseline

    mse = torch.nn.functional.mse_loss
    plus_loss = mse(
        plus_delta.float(), uncond_offset.float(), reduction="none"
    ).mean([1, 2, 3])
    minus_loss = mse(
        minus_delta.float(), cond_offset.float(), reduction="none"
    ).mean([1, 2, 3])

    loss = ((plus_loss + minus_loss) / 2.0).mean()
    self.accelerator.backward(loss)

    # Detach so the parent class can run backward again without an error.
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
|
||||
|
||||
def get_guided_loss_masked_polarity(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        **kwargs
):
    """Masked-polarity guided loss (working title).

    The noised conditional latents are predicted with the network multipliers
    at ``+1.0`` and the noised unconditional latents with them at ``-1.0``,
    both in one doubled batch. Each half is regressed (MSE) onto ``noise``.
    A differential mask, built from ``|conditional - unconditional|``, weights
    the positive half by ``1 + mask`` and the negative half by ``1 - mask``.

    ``self.accelerator.backward`` is called on the loss inside this method.

    Args:
        noisy_latents: Not read by this method.
        conditional_embeds: Prompt embeddings used for both halves.
        match_adapter_assist: Not read by this method.
        network_weight_list: Per-sample network multipliers.
        timesteps: Timesteps handed to ``add_noise`` / ``predict_noise``.
        pred_kwargs: Extra keyword arguments forwarded to ``predict_noise``.
        batch: Supplies ``latents`` and ``unconditional_latents``.
        noise: Noise mixed into the latents; also the regression target.
        **kwargs: Accepted and ignored.

    Returns:
        The scalar loss, detached and then flagged ``requires_grad=True``.
    """
    # Input preparation and mask construction need no autograd graph.
    with torch.no_grad():
        dtype = get_torch_dtype(self.train_config.dtype)

        conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()

        # Differential mask: per-sample absolute difference, scaled so the
        # largest value over dims 1-3 is 1.
        differential_mask = torch.abs(conditional_latents - unconditional_latents)
        max_differential = differential_mask.amax(dim=(1, 2, 3), keepdim=True)
        # An identical conditional/unconditional pair has a maximum of 0, and
        # 1/0 would turn the mask (and the loss) into NaN. Substituting 1
        # keeps that sample's mask at 0 and leaves non-zero maxima untouched.
        max_differential = max_differential.masked_fill(max_differential == 0, 1.0)
        differential_scaler = 1.0 / max_differential
        differential_mask = differential_mask * differential_scaler

        # Steepen the mask around the spread point, then clip it to [0, 1].
        spread_point = 0.1
        differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)

        conditional_noisy_latents = self.sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        unconditional_noisy_latents = self.sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # The network is switched off here and back on below. No baseline
        # prediction currently runs in between; the toggle is kept so the
        # final module state matches the original sequence.
        self.network.is_active = False
        self.sd.unet.eval()

        # Double up everything to run both polarities through at once.
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # Positive half uses the weights as given, negative half their negation.
        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        cat_network_weight_list = positive_network_weights + negative_network_weights

        # Turn the LoRA network back on.
        self.sd.unet.train()
        self.network.is_active = True

        self.network.multiplier = cat_network_weight_list

    # This pass must run with autograd enabled: backward() is called on its loss.
    prediction = self.sd.predict_noise(
        latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        noise.float(),
        reduction="none"
    )
    # Apply the mask: differing regions weigh more on the positive half.
    pred_loss = pred_loss * (1.0 + differential_mask)
    pred_loss = pred_loss.mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        noise.float(),
        reduction="none"
    )
    # Apply the inverse mask: differing regions weigh less on the negative half.
    pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
    pred_neg_loss = pred_neg_loss.mean([1, 2, 3])

    loss = pred_loss + pred_neg_loss
    loss = loss.mean()
    self.accelerator.backward(loss)

    # Detach so the parent class can run backward again without an error.
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
|
||||
|
||||
def get_prior_prediction(
|
||||
self,
|
||||
@@ -985,6 +714,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
batch=batch,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
if was_unet_training:
|
||||
@@ -1021,6 +751,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timesteps: Union[int, torch.Tensor] = 1,
|
||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
batch: Optional['DataLoaderBatchDTO'] = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -1034,6 +765,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
||||
batch=batch,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -1690,6 +1422,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
batch=batch,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
@@ -1723,6 +1456,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
batch=batch,
|
||||
**pred_kwargs
|
||||
)
|
||||
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
|
||||
|
||||
Reference in New Issue
Block a user