mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
813 lines
28 KiB
Python
813 lines
28 KiB
Python
import torch
|
|
from typing import Literal, Optional
|
|
|
|
from toolkit.basic import value_map
|
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
|
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
|
from toolkit.stable_diffusion_model import StableDiffusion
|
|
from toolkit.train_tools import get_torch_dtype
|
|
from toolkit.config_modules import TrainConfig
|
|
|
|
# Names of the guidance strategies that ``get_guidance_loss`` dispatches on.
# "tnt" and "targeted_flow" are handled by the dispatcher as well, so they are
# listed here to keep the alias in sync with the code that consumes it.
GuidanceType = Literal[
    "targeted",
    "polarity",
    "targeted_polarity",
    "direct",
    "tnt",
    "targeted_flow",
]

# Fraction of the conditional/unconditional latent difference that
# ``get_targeted_polarity_loss`` mixes into the noise it trains against.
DIFFERENTIAL_SCALER = 0.2

# DIFFERENTIAL_SCALER = 0.25
|
|
|
|
|
|
def get_differential_mask(
        conditional_latents: torch.Tensor,
        unconditional_latents: torch.Tensor,
        threshold: float = 0.2,
        gradient: bool = False,
):
    """Build a mask that highlights where two latents differ.

    The absolute difference of the two inputs is normalized by its maximum
    taken over dims 1, 2 and 3 (so every item along dim 0 is scaled by its own
    peak) and then turned into a mask.

    Args:
        conditional_latents: First latent tensor. Must have at least 4 dims
            because the reduction indexes dims 1-3.
        unconditional_latents: Second latent tensor, broadcastable against the
            first.
        threshold: In hard mode, normalized differences below this become 0.0
            and everything else 1.0. In gradient mode it is the margin added on
            both sides of the output range before clamping.
        gradient: When True return a soft mask in [0, 1]; otherwise a hard
            0/1 mask.

    Returns:
        A tensor shaped like the element-wise difference of the inputs.
    """
    # make a differential mask
    differential_mask = torch.abs(conditional_latents - unconditional_latents)
    max_differential = \
        differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

    # An item whose two latents are identical has a peak of exactly 0. Dividing
    # by it would give 0 * inf = NaN, and NaN compares False against the
    # threshold, which would mark the whole item as "different". Normalizing
    # by 1 instead keeps such an item at an all-zero difference.
    safe_max_differential = torch.where(
        max_differential == 0,
        torch.ones_like(max_differential),
        max_differential
    )
    differential_scaler = 1.0 / safe_max_differential
    differential_mask = differential_mask * differential_scaler

    if gradient:
        # we need to scale it to 0-1:
        # stretch the observed range onto [-threshold, 1 + threshold] and clip,
        # so values near either end saturate at 0.0 or 1.0.
        # NOTE(review): if the whole mask is constant, min == max here; how
        # value_map handles that depends on its implementation - confirm.
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            0 - threshold,
            1 + threshold
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # make everything less than the threshold be 0.0 and everything else be 1.0
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask)
        )
    return differential_mask
|
|
|
|
|
|
def get_targeted_polarity_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Run one targeted-polarity guidance step and return its detached loss.

    ``batch.latents`` and ``batch.unconditional_latents`` are each noised with
    ``noise`` shifted by ``DIFFERENTIAL_SCALER`` times the difference pointing
    at the *other* latent. Both noised latents go through a single
    ``sd.predict_noise`` call and each half is regressed (MSE) onto the shifted
    noise that was used to build it.

    ``loss.backward()`` is called inside this function. The returned scalar is
    detached and then flagged ``requires_grad`` so that a later ``backward()``
    by the caller does not raise.

    ``noisy_latents``, ``match_adapter_assist``, ``network_weight_list`` and
    ``kwargs`` are accepted for signature parity and are not read here.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    mse = torch.nn.functional.mse_loss

    with torch.no_grad():
        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # scaled offsets pointing from each latent toward the other one
        toward_uncond = (uncond_latents - cond_latents) * DIFFERENTIAL_SCALER
        toward_cond = (cond_latents - uncond_latents) * DIFFERENTIAL_SCALER
        toward_cond = toward_cond.detach().requires_grad_(False)
        toward_uncond = toward_uncond.detach().requires_grad_(False)

        # Plainly noised latents. Nothing below reads them; the two calls are
        # kept so the sequence of calls made on `sd` stays the same.
        _baseline_cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        _baseline_uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()

        # each side gets the shared noise nudged toward the opposite latent
        conditional_noise = noise + toward_uncond
        unconditional_noise = noise + toward_cond

        cond_noisy = sd.add_noise(cond_latents, conditional_noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_latents, unconditional_noise, timesteps).detach()

        # stack both sides so a single forward pass covers them
        stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        stacked_latents = torch.cat([cond_noisy, uncond_noisy], dim=0)
        stacked_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # gradients are needed from here on
    sd.unet.train()

    prediction = sd.predict_noise(
        latents=stacked_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=stacked_embeds.to(device, dtype=dtype).detach(),
        timestep=stacked_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    # first half belongs to the conditional latents, second to the unconditional
    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pos_term = mse(pred_pos.float(), conditional_noise.float(), reduction="none").mean([1, 2, 3])
    neg_term = mse(pred_neg.float(), unconditional_noise.float(), reduction="none").mean([1, 2, 3])

    loss = (pos_term + neg_term).mean()
    loss.backward()

    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
|
|
|
|
def get_direct_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional[PromptEmbeds] = None,
        mask_multiplier=None,
        prior_pred=None,
        **kwargs
):
    """Run one "direct" guidance step and return its detached loss.

    The unconditional and conditional latents of ``batch`` are noised with the
    same ``noise`` and predicted in one stacked ``sd.predict_noise`` call. The
    two predictions are blended as ``uncond + 1.1 * (cond - uncond)`` and the
    blend is regressed (MSE) onto ``noise``. When ``mask_multiplier`` is given
    it weights the element-wise loss before reduction.

    ``loss.backward()`` is called inside this function. The returned scalar is
    detached and then flagged ``requires_grad`` so that a later ``backward()``
    by the caller does not raise.

    ``noisy_latents``, ``match_adapter_assist``, ``network_weight_list``,
    ``prior_pred`` and ``kwargs`` are not read here.
    """
    with torch.no_grad():
        # Perform targeted guidance (working title)
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()

    # gradients are needed from here on
    sd.unet.train()

    if unconditional_embeds is not None:
        # duplicated to line up with the stacked latents below
        unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
        unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])

    # note the order: unconditional half first, conditional half second
    stacked_latents = torch.cat([uncond_noisy, cond_noisy]).to(device, dtype=dtype).detach()
    stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
    stacked_embeds = stacked_embeds.to(device, dtype=dtype).detach()

    prediction = sd.predict_noise(
        latents=stacked_latents,
        conditional_embeddings=stacked_embeds,
        unconditional_embeddings=unconditional_embeds,
        timestep=torch.cat([timesteps, timesteps]),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)

    # push slightly past the conditional prediction, CFG style
    guidance_scale = 1.1
    guidance_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    guidance_loss = torch.nn.functional.mse_loss(
        guidance_pred.float(),
        noise.detach().float(),
        reduction="none"
    )
    if mask_multiplier is not None:
        guidance_loss = guidance_loss * mask_multiplier

    # per-sample mean first, then mean over the batch
    loss = guidance_loss.mean([1, 2, 3]).mean()

    loss.backward()

    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
|
|
|
|
|
|
# targeted
|
|
def get_targeted_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Run one "targeted" guidance step and return its detached loss.

    Steps, as implemented below:

    1. Noise the conditional and unconditional latents of ``batch`` with the
       same ``noise`` via ``sd.noise_scheduler.add_noise``.
    2. With ``sd.network`` deactivated, predict on the unconditional noisy
       latents and keep the element-wise MSE against ``noise`` as a prior.
    3. With the network active, predict on both noisy latents in one stacked
       call and compute the element-wise MSE of each half against ``noise``.
    4. The loss is ``|conditional - prior|`` weighted by a per-element scaler
       in ``[1.0, 2.0]`` derived from the latent difference, plus
       ``|conditional - unconditional|``.

    ``backward()`` is called inside this function. The returned scalar is
    detached and then flagged ``requires_grad`` so that a later ``backward()``
    by the caller does not raise. ``sd.network.multiplier`` is reset to
    ``network_weight_list`` before returning.

    ``noisy_latents``, ``match_adapter_assist`` and ``kwargs`` are not read.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # noise both latents with the same noise and timesteps
        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

        # bypass the LoRA network for the baseline prediction
        sd.network.is_active = False
        sd.unet.eval()

        target_differential = unconditional_latents - conditional_latents
        target_differential_abs = target_differential.abs()

        # Per-sample extrema over dims 1-3. The minimum previously chained
        # .min(dim=1) with .max(dim=2).max(dim=3), which is not the smallest
        # value; elements below it were then mapped outside the
        # [min_guidance, max_guidance] range, possibly to negative weights.
        target_differential_abs_min = \
            target_differential_abs.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
        target_differential_abs_max = \
            target_differential_abs.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

        min_guidance = 1.0
        max_guidance = 2.0

        # larger latent difference -> larger weight on the positive loss
        differential_scaler = value_map(
            target_differential_abs,
            target_differential_abs_min,
            target_differential_abs_max,
            min_guidance,
            max_guidance
        ).detach()

        # With LoRA network bypassed, predict noise to get a baseline of what the network
        # wants to do with the latents + noise. Pass our target latents here for the input.
        target_unconditional = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()
        prior_prediction_loss = torch.nn.functional.mse_loss(
            target_unconditional.float(),
            noise.float(),
            reduction="none"
        ).detach().clone()

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    # One multiplier per stacked sample: the given weights for the conditional
    # half, each weight minus 1.0 for the unconditional half.
    # NOTE(review): get_guided_loss_polarity uses `weight * -1.0` for its
    # second half; confirm that `x + -1.0` (not `x * -1.0`) is intended here.
    sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

    # with LoRA active, predict on both noisy latents in a single stacked call
    prediction = sd.predict_noise(
        latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
        conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
        timestep=torch.cat([timesteps, timesteps], dim=0),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

    conditional_loss = torch.nn.functional.mse_loss(
        prediction_conditional.float(),
        noise.float(),
        reduction="none"
    )

    unconditional_loss = torch.nn.functional.mse_loss(
        prediction_unconditional.float(),
        noise.float(),
        reduction="none"
    )

    # distance of the conditional loss from the network-off prior
    positive_loss = torch.abs(
        conditional_loss.float() - prior_prediction_loss.float(),
    )
    # scale our loss by the differential scaler
    positive_loss = positive_loss * differential_scaler

    positive_loss = positive_loss.mean([1, 2, 3])

    # distance between the two halves of the stacked prediction
    polar_loss = torch.abs(
        conditional_loss.float() - unconditional_loss.float(),
    ).mean([1, 2, 3])

    positive_loss = positive_loss.mean() + polar_loss.mean()

    positive_loss.backward()
    loss = positive_loss.detach()

    # add a grad so other backward does not fail
    loss.requires_grad_(True)

    # restore network
    sd.network.multiplier = network_weight_list

    return loss
|
|
|
|
def get_guided_loss_polarity(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        train_config: 'TrainConfig',
        scaler=None,
        **kwargs
):
    """Run one "polarity" guidance step and return its detached loss.

    The conditional and unconditional latents of ``batch`` are noised with the
    same ``noise`` and predicted in one stacked call. ``sd.network.multiplier``
    is set to the given weights for the conditional half and to the negated
    weights for the unconditional half. Each half is regressed (MSE) onto its
    target: ``noise`` by default, or ``noise - latents`` when
    ``sd.is_flow_matching`` is true.

    ``backward()`` is called inside this function, through
    ``scaler.scale(loss)`` when ``scaler`` is given. The returned scalar is
    detached and then flagged ``requires_grad`` so that a later ``backward()``
    by the caller does not raise.

    ``noisy_latents``, ``match_adapter_assist`` and ``kwargs`` are not read.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # default targets: the noise itself
        target_pos = noise
        target_neg = noise

        if sd.is_flow_matching:
            use_linear = any([
                train_config.linear_timesteps,
                train_config.linear_timesteps2,
                train_config.timestep_type == 'linear',
            ])
            timestep_type = 'linear' if use_linear else train_config.timestep_type

            sd.noise_scheduler.set_train_timesteps(
                1000,
                device=device,
                timestep_type=timestep_type,
                latents=cond_latents
            )
            # flow-matching targets: noise minus the clean latents
            target_pos = (noise - cond_latents).detach()
            target_neg = (noise - uncond_latents).detach()

        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()

        # stack both sides so a single forward pass covers them
        stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        stacked_latents = torch.cat([cond_noisy, uncond_noisy], dim=0)
        stacked_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # positive weights for the first half, negated weights for the second
        flipped_weights = [weight * -1.0 for weight in network_weight_list]
        plain_weights = [weight * 1.0 for weight in network_weight_list]
        stacked_weights = plain_weights + flipped_weights

    # gradients are needed from here on
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = stacked_weights

    prediction = sd.predict_noise(
        latents=stacked_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=stacked_embeds.to(device, dtype=dtype).detach(),
        timestep=stacked_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    mse = torch.nn.functional.mse_loss
    pos_term = mse(pred_pos.float(), target_pos.float(), reduction="none")
    neg_term = mse(pred_neg.float(), target_neg.float(), reduction="none")

    # sum element-wise, reduce per sample, then over the batch
    loss = (pos_term + neg_term).mean([1, 2, 3]).mean()

    if scaler is None:
        loss.backward()
    else:
        scaler.scale(loss).backward()

    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
|
|
|
|
|
|
|
|
def get_guided_tnt(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        prior_pred: torch.Tensor = None,
        **kwargs
):
    """Run one "tnt" (this, not that) guidance step and return its detached loss.

    The conditional ("this") and unconditional ("that") latents of ``batch``
    are noised with the same ``noise`` and predicted in one stacked call. The
    per-sample MSE of the "this" half against ``noise`` is added to the
    negated MSE of the "that" half, after the latter is rescaled to the
    magnitude of the former and multiplied by 0.01.

    ``loss.backward()`` is called inside this function. The returned scalar is
    detached and then flagged ``requires_grad`` so that a later ``backward()``
    by the caller does not raise.

    ``noisy_latents``, ``match_adapter_assist``, ``prior_pred`` and ``kwargs``
    are not read here.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()

        # stack both sides so a single forward pass covers them
        stacked_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        stacked_latents = torch.cat([cond_noisy, uncond_noisy], dim=0)
        stacked_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # gradients are needed from here on
    sd.unet.train()
    if sd.network is not None:
        # same weights for both halves of the stacked batch
        sd.network.multiplier = list(network_weight_list * 2)
        sd.network.is_active = True

    prediction = sd.predict_noise(
        latents=stacked_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=stacked_embeds.to(device, dtype=dtype).detach(),
        timestep=stacked_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

    mse = torch.nn.functional.mse_loss
    this_loss = mse(this_prediction.float(), noise.float(), reduction="none").mean([1, 2, 3])
    # negative loss on that
    that_loss = -mse(that_prediction.float(), noise.float(), reduction="none").mean([1, 2, 3])

    with torch.no_grad():
        # ratio that brings |that_loss| to the magnitude of |this_loss|;
        # computed without grad so it acts as a constant factor
        that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss)

    that_loss = that_loss * that_loss_scaler * 0.01

    loss = (this_loss + that_loss).mean()

    loss.backward()

    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
|
|
|
|
def targeted_flow_guidance(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional[PromptEmbeds] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        train_config=None,
        **kwargs
):
    """Compute the "targeted_flow" guidance loss for flow-matching models.

    With ``sd.network`` deactivated, a baseline prediction is made on the
    noised unconditional latents. Adding the unconditional latents to it
    inverts the flow-matching target (``noise - latents``) and yields the
    noise the base model implies. That noise, minus the conditional latents,
    is the target for the network-active prediction on the noised conditional
    latents.

    Unlike the other guidance losses in this file, this function does not call
    ``backward()``; it returns the MSE loss with its graph attached.

    Raises:
        ValueError: if ``sd.is_flow_matching`` is false.

    ``noisy_latents``, ``match_adapter_assist``, ``unconditional_embeds``,
    ``mask_multiplier``, ``prior_pred``, ``scaler``, ``train_config`` and
    ``kwargs`` are not read here.
    """
    if not sd.is_flow_matching:
        raise ValueError("targeted_flow only works on flow matching models")

    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        cond_latents = batch.latents.to(device, dtype=dtype).detach()
        uncond_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # noisy inputs for the baseline pass and the trained pass
        uncond_noisy = sd.add_noise(uncond_latents, noise, timesteps).detach()
        cond_noisy = sd.add_noise(cond_latents, noise, timesteps).detach()

        # baseline pass: network off, eval mode
        sd.network.is_active = False
        sd.unet.eval()

        baseline_prediction = sd.predict_noise(
            latents=uncond_noisy.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        ).detach()

        # The usual flow-matching target is `noise - latents`, so adding the
        # unconditional latents back recovers the noise implied by the
        # baseline prediction.
        target_noise = baseline_prediction + uncond_latents

        # Re-target that noise at the conditional latents: the only new
        # information left is the difference between the two latents.
        target_pred = (target_noise - cond_latents).detach()

    # trained pass: network on, train mode
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list

    prediction = sd.predict_noise(
        latents=cond_noisy.to(device, dtype=dtype).detach(),
        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    # target our baseline + differential noise target
    return torch.nn.functional.mse_loss(prediction.float(), target_pred.float())
|
|
|
|
|
|
# this processes all guidance losses based on the batch information
|
|
def get_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional['PromptEmbeds'] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        train_config=None,
        **kwargs
):
    """Dispatch to the guidance loss selected by the batch's dataset config.

    The guidance type is read from the first file item of ``batch`` only
    (``batch.file_items[0].dataset_config.guidance_type``) and applied to the
    whole batch. The first nine positional arguments are forwarded unchanged;
    the optional arguments are forwarded only to the implementations that
    accept them.

    Returns:
        Whatever the selected guidance function returns.

    Raises:
        AssertionError: if ``unconditional_embeds`` is given for "targeted",
            "polarity", "tnt" or "targeted_polarity" guidance.
        NotImplementedError: if the guidance type is not handled here.
    """
    # TODO add others and process individual batch items separately
    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type

    # positional arguments shared by every guidance implementation
    common_args = (
        noisy_latents,
        conditional_embeds,
        match_adapter_assist,
        network_weight_list,
        timesteps,
        pred_kwargs,
        batch,
        noise,
        sd,
    )

    if guidance_type == "targeted":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
        return get_targeted_guidance_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
        return get_guided_loss_polarity(
            *common_args,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    elif guidance_type == "tnt":
        # the message used to say "polarity guidance" (copy-paste from the branch above)
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(
            *common_args,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
        return get_targeted_polarity_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_flow":
        return targeted_flow_guidance(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
|