# ai-toolkit/toolkit/guidance.py
import torch
from typing import Literal

from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype

GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]

DIFFERENTIAL_SCALER = 0.2
# DIFFERENTIAL_SCALER = 0.25


def get_differential_mask(
        conditional_latents: torch.Tensor,
        unconditional_latents: torch.Tensor,
        threshold: float = 0.2,
        gradient: bool = False,
):
    # make a differential mask from the absolute difference of the two latents
    differential_mask = torch.abs(conditional_latents - unconditional_latents)
    # per-sample max over channel and spatial dims, used to normalize to 0-1
    max_differential = \
        differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
    differential_scaler = 1.0 / max_differential
    differential_mask = differential_mask * differential_scaler

    if gradient:
        # we need to scale it to 0-1
        # differential_mask = differential_mask - differential_mask.min()
        # differential_mask = differential_mask / differential_mask.max()
        # soft mask: stretch the range by `threshold` on both sides, then clip
        # back to 0-1 so the extremes saturate
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            0 - threshold,
            1 + threshold
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # hard mask: everything below `threshold` becomes 0.0, the rest 1.0
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask)
        )

    return differential_mask
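

def _example_differential_mask():
    # Illustrative sketch only (not called anywhere in this module): shows the
    # expected shapes for get_differential_mask. The (2, 4, 64, 64) latent
    # shape is an assumption for the example, not a requirement of the function.
    cond = torch.randn(2, 4, 64, 64)
    uncond = cond + 0.5 * torch.randn(2, 4, 64, 64)
    hard_mask = get_differential_mask(cond, uncond, threshold=0.2)
    soft_mask = get_differential_mask(cond, uncond, threshold=0.2, gradient=True)
    # both masks match the input shape, with values in [0.0, 1.0]
    return hard_mask, soft_mask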


def get_targeted_polarity_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
        # inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
        # noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
        differential_scaler = DIFFERENTIAL_SCALER

        # scaled differentials, used as extra "noise" that pushes each latent
        # toward its counterpart
        unconditional_diff = (unconditional_latents - conditional_latents)
        unconditional_diff_noise = unconditional_diff * differential_scaler
        conditional_diff = (conditional_latents - unconditional_latents)
        conditional_diff_noise = conditional_diff * differential_scaler
        conditional_diff_noise = conditional_diff_noise.detach().requires_grad_(False)
        unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)

        # the baseline latents are only needed by the commented-out baseline
        # prediction below
        baseline_conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
        baseline_unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        conditional_noise = noise + unconditional_diff_noise
        unconditional_noise = noise + conditional_diff_noise

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            conditional_noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            unconditional_noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # cat_baseline_noisy_latents = torch.cat(
        #     [baseline_conditional_noisy_latents, baseline_unconditional_noisy_latents],
        #     dim=0
        # )

        # Disable the LoRA network so we can predict parent network knowledge without it
        # sd.network.is_active = False
        # sd.unet.eval()

        # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
        # This acts as our control to preserve the unaltered parts of the image.
        # baseline_prediction = sd.predict_noise(
        #     latents=cat_baseline_noisy_latents.to(device, dtype=dtype).detach(),
        #     conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        #     timestep=cat_timesteps,
        #     guidance_scale=1.0,
        #     **pred_kwargs  # adapter residuals in here
        # ).detach()
        # conditional_baseline_prediction, unconditional_baseline_prediction = torch.chunk(baseline_prediction, 2, dim=0)

        # negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        # positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        # cat_network_weight_list = positive_network_weights + negative_network_weights

    # turn the LoRA network back on.
    sd.unet.train()
    # sd.network.is_active = True
    # sd.network.multiplier = cat_network_weight_list

    # do our prediction with LoRA active on the scaled guidance latents
    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    # prediction = prediction - baseline_prediction

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
    # pred_pos = pred_pos - conditional_baseline_prediction
    # pred_neg = pred_neg - unconditional_baseline_prediction

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        conditional_noise.float(),
        reduction="none"
    )
    pred_loss = pred_loss.mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        unconditional_noise.float(),
        reduction="none"
    )
    pred_neg_loss = pred_neg_loss.mean([1, 2, 3])

    loss = pred_loss + pred_neg_loss
    loss = loss.mean()

    loss.backward()
    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
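

def _example_polarity_noise_targets():
    # Illustrative sketch only: the core trick in get_targeted_polarity_loss.
    # Each branch's training target is the base noise shifted by the scaled
    # differential toward the *other* latent. The shapes here are hypothetical.
    cond = torch.randn(1, 4, 64, 64)
    uncond = torch.randn(1, 4, 64, 64)
    noise = torch.randn(1, 4, 64, 64)
    conditional_noise = noise + (uncond - cond) * DIFFERENTIAL_SCALER
    unconditional_noise = noise + (cond - uncond) * DIFFERENTIAL_SCALER
    # the two shifts are opposite, so they cancel: the mean of the two targets
    # is the base noise itself
    assert torch.allclose((conditional_noise + unconditional_noise) / 2, noise, atol=1e-5)
    return conditional_noise, unconditional_noise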


def get_direct_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    with torch.no_grad():
        # Perform targeted guidance (working title)
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            # target_noise,
            noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

    # turn the LoRA network back on.
    sd.unet.train()
    # sd.network.is_active = True
    # sd.network.multiplier = network_weight_list

    # do our prediction with LoRA active on the scaled guidance latents
    prediction = sd.predict_noise(
        latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
        conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
        timestep=torch.cat([timesteps, timesteps]),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)

    guidance_scale = 1.0
    guidance_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_cond - noise_pred_uncond
    )

    guidance_loss = torch.nn.functional.mse_loss(
        guidance_pred.float(),
        noise.detach().float(),
        reduction="none"
    )
    guidance_loss = guidance_loss.mean([1, 2, 3])
    guidance_loss = guidance_loss.mean()

    # loss = guidance_loss + masked_noise_loss
    loss = guidance_loss

    loss.backward()
    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
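

def _example_cfg_combination():
    # Illustrative sketch only: the classifier-free-guidance-style combination
    # used in get_direct_guidance_loss. With guidance_scale == 1.0 the combined
    # prediction reduces to the conditional prediction.
    noise_pred_uncond = torch.randn(1, 4, 64, 64)
    noise_pred_cond = torch.randn(1, 4, 64, 64)
    guidance_scale = 1.0
    guidance_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    assert torch.allclose(guidance_pred, noise_pred_cond, atol=1e-5)
    return guidance_pred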


# targeted
def get_targeted_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        # create the differential mask from the actual image tensors
        conditional_imgs = batch.tensor.to(device, dtype=dtype).detach()
        unconditional_imgs = batch.unconditional_tensor.to(device, dtype=dtype).detach()
        differential_mask = torch.abs(conditional_imgs - unconditional_imgs)
        # normalize per sample to the 0-1 range
        differential_mask = differential_mask - differential_mask.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
        differential_mask = differential_mask / differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

        # differential_mask is (bs, 3, width, height)
        # latents are (bs, 4, width, height)
        # reduce the mean on dim 1 to get a single-channel mask and stack it to match the latents
        differential_mask = differential_mask.mean(dim=1, keepdim=True)
        differential_mask = torch.cat([differential_mask] * 4, dim=1)

        # scale the mask down to latent size
        differential_mask = torch.nn.functional.interpolate(
            differential_mask,
            size=noisy_latents.shape[2:],
            mode="nearest"
        )

        conditional_noisy_latents = noisy_latents
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
        # unconditional_as_noise = unconditional_latents - conditional_latents
        # conditional_as_noise = conditional_latents - unconditional_latents

        # add noise to the unconditional and conditional latents
        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

        # was_network_active = self.network.is_active
        sd.network.is_active = False
        sd.unet.eval()

        # calculate the differential between our conditional (target image) and our unconditional ("bad" image)
        # target_differential = unconditional_noisy_latents - conditional_noisy_latents
        target_differential = unconditional_latents - conditional_latents
        # target_differential = conditional_latents - unconditional_latents

        # scale the target differential by the scheduler
        # todo, scale it the right way
        # target_differential = sd.noise_scheduler.add_noise(
        #     torch.zeros_like(target_differential),
        #     target_differential,
        #     timesteps
        # )

        # noise_abs_mean = torch.abs(noise + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
        # target_differential = target_differential.detach()
        # target_differential_abs_mean = torch.abs(target_differential + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
        # # determines the scaler to adjust to the same abs mean as the noise
        # scaler = noise_abs_mean / target_differential_abs_mean

        target_differential_knowledge = target_differential
        target_differential_knowledge = target_differential_knowledge.detach()

        # add the target differential to the target latents as if it were noise, with the scheduler scaled to
        # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
        # properly
        # guidance_latents = sd.noise_scheduler.add_noise(
        #     conditional_noisy_latents,
        #     target_differential,
        #     timesteps
        # )
        # guidance_latents = conditional_noisy_latents + target_differential
        # target_noise = conditional_noisy_latents + target_differential

        # With the LoRA network bypassed, predict noise to get a baseline of what the network
        # wants to do with the latents + noise. Pass our target latents here for the input.
        target_unconditional = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()
        # target_conditional = sd.predict_noise(
        #     latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
        #     conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        #     timestep=timesteps,
        #     guidance_scale=1.0,
        #     **pred_kwargs  # adapter residuals in here
        # ).detach()

        # we calculate the network's current knowledge so we do not overlearn what we already know
        # parent_knowledge = target_unconditional - target_conditional
        # parent_knowledge = parent_knowledge.detach()
        # del target_conditional
        # del target_unconditional

        # we now have the differential noise prediction needed to create our convergence target
        # target_unknown_knowledge = target_differential + parent_knowledge
        # del parent_knowledge

        prior_prediction_loss = torch.nn.functional.mse_loss(
            target_unconditional.float(),
            noise.float(),
            reduction="none"
        ).detach().clone()

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list

    # with the LoRA active, predict the noise with the scaled differential latents added. This gives us
    # the opportunity to predict the differential + noise that was added to the latents.
    prediction_conditional = sd.predict_noise(
        latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    # remove the baseline conditional prediction. This will leave only the divergence from the baseline and
    # the prediction of the added differential noise
    # prediction_positive = prediction_unconditional - target_unconditional
    # current_knowledge = target_unconditional - prediction_conditional
    # current_differential_knowledge = prediction_conditional - target_unconditional
    # current_unknown_knowledge = parent_knowledge - current_knowledge
    #
    # current_unknown_knowledge_abs_mean = torch.abs(current_unknown_knowledge + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
    # current_unknown_knowledge_std = current_unknown_knowledge / current_unknown_knowledge_abs_mean

    # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
    # this is the diffusion training process.
    # This will guide the network to make predictions identical to its previous ones for everything EXCEPT our
    # differential between the conditional and unconditional images
    # positive_loss = torch.nn.functional.mse_loss(
    #     current_differential_knowledge.float(),
    #     target_differential_knowledge.float(),
    #     reduction="none"
    # )

    normal_loss = torch.nn.functional.mse_loss(
        prediction_conditional.float(),
        noise.float(),
        reduction="none"
    )

    # # scale positive and neutral loss to the same scale
    # positive_loss_abs_mean = torch.abs(positive_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
    # normal_loss_abs_mean = torch.abs(normal_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
    # scaler = normal_loss_abs_mean / positive_loss_abs_mean
    # positive_loss = positive_loss * scaler
    # positive_loss = positive_loss * differential_mask
    # positive_loss = positive_loss
    # masked_normal_loss = normal_loss * differential_mask

    prior_loss = torch.abs(
        normal_loss.float() - prior_prediction_loss.float(),
        # ) * (1 - differential_mask)
    )
    decouple = True

    # positive_loss_full = positive_loss
    # prior_loss_full = prior_loss
    #
    # current_scaler = (prior_loss_full.max() / positive_loss_full.max())
    # # positive_loss = positive_loss * current_scaler
    # avg_scaler_arr.append(current_scaler.item())
    # avg_scaler = sum(avg_scaler_arr) / len(avg_scaler_arr)
    # print(f"avg scaler: {avg_scaler}, current scaler: {current_scaler.item()}")
    # # remove extra scalers beyond 100
    # if len(avg_scaler_arr) > 100:
    #     avg_scaler_arr.pop(0)
    #
    # # positive_loss = positive_loss * avg_scaler
    # positive_loss = positive_loss * avg_scaler * 0.1

    if decouple:
        # positive_loss = positive_loss.mean([1, 2, 3])
        prior_loss = prior_loss.mean([1, 2, 3])
        # masked_normal_loss = masked_normal_loss.mean([1, 2, 3])
        positive_loss = prior_loss
        # positive_loss = positive_loss + prior_loss
    else:
        # positive_loss = positive_loss + prior_loss
        positive_loss = prior_loss
        positive_loss = positive_loss.mean([1, 2, 3])

    # positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
    # send it backwards BEFORE switching network polarity
    # positive_loss = self.apply_snr(positive_loss, timesteps)
    positive_loss = positive_loss.mean()
    positive_loss.backward()

    # loss = positive_loss.detach() + negative_loss.detach()
    loss = positive_loss.detach()
    # add a grad so the other backward does not fail
    loss.requires_grad_(True)

    # restore network
    sd.network.multiplier = network_weight_list

    return loss
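

def _example_prior_preservation_term():
    # Illustrative sketch only: the prior-preservation idea behind prior_loss in
    # get_targeted_guidance_loss. The trainable network is penalized by how far
    # its per-element MSE drifts from the frozen network's MSE on the same noise
    # target, rather than by the raw MSE itself. The tensors below are
    # hypothetical stand-ins for real predictions.
    noise = torch.randn(1, 4, 64, 64)
    frozen_pred = noise + 0.1 * torch.randn(1, 4, 64, 64)  # stands in for target_unconditional
    trained_pred = noise + 0.1 * torch.randn(1, 4, 64, 64)  # stands in for prediction_conditional
    prior_prediction_loss = torch.nn.functional.mse_loss(frozen_pred, noise, reduction="none")
    normal_loss = torch.nn.functional.mse_loss(trained_pred, noise, reduction="none")
    prior_loss = torch.abs(normal_loss - prior_prediction_loss)
    # zero wherever the trained network matches the frozen network's error profile
    return prior_loss.mean()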


def get_guided_loss_polarity(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # apply the LoRA at +1x for the conditional half and -1x for the unconditional half
        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        cat_network_weight_list = positive_network_weights + negative_network_weights

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = cat_network_weight_list

    # do our prediction with LoRA active on the scaled guidance latents
    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        noise.float(),
        reduction="none"
    )
    # pred_loss = pred_loss.mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        noise.float(),
        reduction="none"
    )

    loss = pred_loss + pred_neg_loss
    loss = loss.mean([1, 2, 3])
    loss = loss.mean()

    loss.backward()
    # detach it so the parent class can run backward on no grads without throwing an error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
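

def _example_polarity_multipliers():
    # Illustrative sketch only: how get_guided_loss_polarity applies the LoRA at
    # +1x for the conditional half of the batch and -1x for the unconditional
    # half in a single forward pass. The two-weight list is hypothetical.
    network_weight_list = [1.0, 1.0]
    positive_network_weights = [weight * 1.0 for weight in network_weight_list]
    negative_network_weights = [weight * -1.0 for weight in network_weight_list]
    # pairs with torch.cat([conditional, unconditional], dim=0) on the latents
    return positive_network_weights + negative_network_weights  # [1.0, 1.0, -1.0, -1.0]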


# this processes all guidance losses based on the batch information
def get_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    # TODO add others and process individual batch items separately
    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
    if guidance_type == "targeted":
        return get_targeted_guidance_loss(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            **kwargs
        )
    elif guidance_type == "polarity":
        return get_guided_loss_polarity(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            **kwargs
        )
    elif guidance_type == "targeted_polarity":
        return get_targeted_polarity_loss(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            **kwargs
        )
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
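

def _example_dispatch(sd, batch, conditional_embeds, noisy_latents, noise, timesteps):
    # Illustrative sketch only: how a training loop might call the dispatcher.
    # The arguments are assumed to come from the surrounding ai-toolkit trainer;
    # nothing here is exercised by this module itself.
    loss = get_guidance_loss(
        noisy_latents=noisy_latents,
        conditional_embeds=conditional_embeds,
        match_adapter_assist=False,
        network_weight_list=[1.0],  # hypothetical LoRA multiplier list
        timesteps=timesteps,
        pred_kwargs={},
        batch=batch,
        noise=noise,
        sd=sd,
    )
    # backward() has already run inside the guidance function; the returned loss
    # is detached with requires_grad re-enabled, so a caller's own backward pass
    # does not raise.
    return loss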