mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-27 01:39:20 +00:00
Added a polarity balancer to guidance
@@ -276,34 +276,9 @@ def get_targeted_guidance_loss(
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch

# create the differential mask from the actual tensors
conditional_imgs = batch.tensor.to(device, dtype=dtype).detach()
unconditional_imgs = batch.unconditional_tensor.to(device, dtype=dtype).detach()
differential_mask = torch.abs(conditional_imgs - unconditional_imgs)
differential_mask = differential_mask - differential_mask.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
differential_mask = differential_mask / differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

# differential_mask is (bs, 3, width, height)
# latents are (bs, 4, width, height)
# reduce the mean on dim 1 to get a single channel mask and stack it to match latents
differential_mask = differential_mask.mean(dim=1, keepdim=True)
differential_mask = torch.cat([differential_mask] * 4, dim=1)

# scale the mask down to latent size
differential_mask = torch.nn.functional.interpolate(
    differential_mask,
    size=noisy_latents.shape[2:],
    mode="nearest"
)

conditional_noisy_latents = noisy_latents

conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

# unconditional_as_noise = unconditional_latents - conditional_latents
# conditional_as_noise = conditional_latents - unconditional_latents

# Encode the unconditional image into latents
unconditional_noisy_latents = sd.noise_scheduler.add_noise(
    unconditional_latents,
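The mask construction above is a per-sample min-max normalisation of the pixel-space difference, collapsed to one channel and resized to latent resolution. A rough standalone sketch of the same idea (function name, shapes, and arguments are illustrative assumptions, not part of the commit):

import torch
import torch.nn.functional as F

def make_differential_mask(cond_imgs: torch.Tensor,
                           uncond_imgs: torch.Tensor,
                           latent_hw: tuple) -> torch.Tensor:
    # absolute pixel difference, (bs, 3, H, W)
    mask = (cond_imgs - uncond_imgs).abs()
    # per-sample min-max normalisation to [0, 1] (same effect as the chained
    # min()/max() reductions over dims 1, 2, 3 in the diff)
    mask = mask - mask.amin(dim=(1, 2, 3), keepdim=True)
    mask = mask / mask.amax(dim=(1, 2, 3), keepdim=True).clamp_min(1e-8)
    # collapse RGB to one channel, then repeat to the 4 latent channels
    mask = mask.mean(dim=1, keepdim=True).repeat(1, 4, 1, 1)
    # downscale to the latent resolution
    return F.interpolate(mask, size=latent_hw, mode="nearest")

Nearest-neighbour interpolation keeps the mask edges hard when downscaling; a bilinear mode would soften them.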
@@ -320,42 +295,25 @@ def get_targeted_guidance_loss(
sd.network.is_active = False
sd.unet.eval()

# calculate the differential between our conditional (target image) and our unconditional ("bad" image)
# target_differential = unconditional_noisy_latents - conditional_noisy_latents
target_differential = unconditional_latents - conditional_latents
# target_differential = conditional_latents - unconditional_latents
# scale our loss by the differential scaler
target_differential_abs = target_differential.abs()
target_differential_abs_min = \
    target_differential_abs.min(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
target_differential_abs_max = \
    target_differential_abs.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

# scale the target differential by the scheduler
# todo, scale it the right way
# target_differential = sd.noise_scheduler.add_noise(
#     torch.zeros_like(target_differential),
#     target_differential,
#     timesteps
# )
min_guidance = 1.0
max_guidance = 2.0

# noise_abs_mean = torch.abs(noise + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
differential_scaler = value_map(
    target_differential_abs,
    target_differential_abs_min,
    target_differential_abs_max,
    min_guidance,
    max_guidance
).detach()

# target_differential = target_differential.detach()
# target_differential_abs_mean = torch.abs(target_differential + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
# # determines scaler to adjust to same abs mean as noise
# scaler = noise_abs_mean / target_differential_abs_mean

target_differential_knowledge = target_differential
target_differential_knowledge = target_differential_knowledge.detach()

# add the target differential to the target latents as if it were noise with the scheduler scaled to
# the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
# properly
# guidance_latents = sd.noise_scheduler.add_noise(
#     conditional_noisy_latents,
#     target_differential,
#     timesteps
# )

# guidance_latents = conditional_noisy_latents + target_differential
# target_noise = conditional_noisy_latents + target_differential

# With LoRA network bypassed, predict noise to get a baseline of what the network
# wants to do with the latents + noise. Pass our target latents here for the input.
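The value_map call maps the per-element magnitude of target_differential into the [min_guidance, max_guidance] range, so regions where the conditional and unconditional latents differ most get the strongest weighting. Assuming value_map is a plain linear range remap (its definition is not shown in this diff), it behaves roughly like:

import torch

def linear_remap(x: torch.Tensor,
                 in_min: torch.Tensor, in_max: torch.Tensor,
                 out_min: float, out_max: float) -> torch.Tensor:
    # element-wise map of x from [in_min, in_max] to [out_min, out_max]
    t = (x - in_min) / (in_max - in_min).clamp_min(1e-8)
    return out_min + t * (out_max - out_min)

# hypothetical usage mirroring the diff:
# differential_scaler = linear_remap(
#     target_differential_abs,
#     target_differential_abs_min, target_differential_abs_max,
#     1.0, 2.0,
# ).detach()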
@@ -366,23 +324,6 @@ def get_targeted_guidance_loss(
    guidance_scale=1.0,
    **pred_kwargs # adapter residuals in here
).detach()
# target_conditional = sd.predict_noise(
#     latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
#     conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
#     timestep=timesteps,
#     guidance_scale=1.0,
#     **pred_kwargs # adapter residuals in here
# ).detach()

# we calculate the network's current knowledge so we do not overlearn what we know
# parent_knowledge = target_unconditional - target_conditional
# parent_knowledge = parent_knowledge.detach()
# del target_conditional
# del target_unconditional

# we now have the differential noise prediction needed to create our convergence target
# target_unknown_knowledge = target_differential + parent_knowledge
# del parent_knowledge
prior_prediction_loss = torch.nn.functional.mse_loss(
    target_unconditional.float(),
    noise.float(),
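prior_prediction_loss is the element-wise error of the LoRA-bypassed model against the true noise; the later hunks penalise the active network only for drifting away from this baseline. A minimal sketch of that baseline step, with predict_fn and network as illustrative stand-ins for the toolkit's sd.predict_noise and sd.network objects:

import torch.nn.functional as F

def frozen_prior_loss(predict_fn, network, noisy_latents, embeds, timesteps, noise):
    # predict with the LoRA weights bypassed so the base model sets the baseline
    network.is_active = False
    with torch.no_grad():
        baseline = predict_fn(noisy_latents, embeds, timesteps)
    network.is_active = True
    # element-wise error of the frozen model against the true noise
    return F.mse_loss(baseline.float(), noise.float(), reduction="none").detach()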
@@ -392,97 +333,48 @@ def get_targeted_guidance_loss(
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

# with LoRA active, predict the noise with the scaled differential latents added. This will allow us
# the opportunity to predict the differential + noise that was added to the latents.
prediction_conditional = sd.predict_noise(
    latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
    conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
    timestep=timesteps,
prediction = sd.predict_noise(
    latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
    conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
    timestep=torch.cat([timesteps, timesteps], dim=0),
    guidance_scale=1.0,
    **pred_kwargs # adapter residuals in here
)

prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

# remove the baseline conditional prediction. This will leave only the divergence from the baseline and
# the prediction of the added differential noise
# prediction_positive = prediction_unconditional - target_unconditional
# current_knowledge = target_unconditional - prediction_conditional
# current_differential_knowledge = prediction_conditional - target_unconditional

# current_unknown_knowledge = parent_knowledge - current_knowledge
#
# current_unknown_knowledge_abs_mean = torch.abs(current_unknown_knowledge + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
# current_unknown_knowledge_std = current_unknown_knowledge / current_unknown_knowledge_abs_mean

# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
# this is the diffusion training process.
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
# differential between the conditional and unconditional images

# positive_loss = torch.nn.functional.mse_loss(
#     current_differential_knowledge.float(),
#     target_differential_knowledge.float(),
#     reduction="none"
# )

normal_loss = torch.nn.functional.mse_loss(
conditional_loss = torch.nn.functional.mse_loss(
    prediction_conditional.float(),
    noise.float(),
    reduction="none"
)
#
# # scale positive and neutral loss to the same scale
# positive_loss_abs_mean = torch.abs(positive_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
# normal_loss_abs_mean = torch.abs(normal_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
# scaler = normal_loss_abs_mean / positive_loss_abs_mean
# positive_loss = positive_loss * scaler

# positive_loss = positive_loss * differential_mask
# positive_loss = positive_loss
# masked_normal_loss = normal_loss * differential_mask

prior_loss = torch.abs(
    normal_loss.float() - prior_prediction_loss.float(),
# ) * (1 - differential_mask)
unconditional_loss = torch.nn.functional.mse_loss(
    prediction_unconditional.float(),
    noise.float(),
    reduction="none"
)

decouple = True
positive_loss = torch.abs(
    conditional_loss.float() - prior_prediction_loss.float(),
)
# scale our loss by the differential scaler
positive_loss = positive_loss * differential_scaler

# positive_loss_full = positive_loss
# prior_loss_full = prior_loss
#
# current_scaler = (prior_loss_full.max() / positive_loss_full.max())
# # positive_loss = positive_loss * current_scaler
# avg_scaler_arr.append(current_scaler.item())
# avg_scaler = sum(avg_scaler_arr) / len(avg_scaler_arr)
# print(f"avg scaler: {avg_scaler}, current scaler: {current_scaler.item()}")
# # remove extra scalers more than 100
# if len(avg_scaler_arr) > 100:
#     avg_scaler_arr.pop(0)
#
# # positive_loss = positive_loss * avg_scaler
# positive_loss = positive_loss * avg_scaler * 0.1
positive_loss = positive_loss.mean([1, 2, 3])

if decouple:
    # positive_loss = positive_loss.mean([1, 2, 3])
    prior_loss = prior_loss.mean([1, 2, 3])
    # masked_normal_loss = masked_normal_loss.mean([1, 2, 3])
    positive_loss = prior_loss
    # positive_loss = positive_loss + prior_loss
else:
    polar_loss = torch.abs(
        conditional_loss.float() - unconditional_loss.float(),
    ).mean([1, 2, 3])

    positive_loss = positive_loss.mean() + polar_loss.mean()

# positive_loss = positive_loss + prior_loss
positive_loss = prior_loss
positive_loss = positive_loss.mean([1, 2, 3])

# positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
# send it backwards BEFORE switching network polarity
# positive_loss = self.apply_snr(positive_loss, timesteps)
positive_loss = positive_loss.mean()
positive_loss.backward()
# loss = positive_loss.detach() + negative_loss.detach()
loss = positive_loss.detach()
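The polarity balancer in this hunk doubles the batch (conditional + unconditional latents) through a single predict_noise call, extends the LoRA multiplier list for the second half, and, in the non-decoupled branch, adds a term that keeps the conditional and unconditional halves at a similar error level. A minimal sketch of that branch (all names are illustrative stand-ins; in the diff itself decouple is set to True, so the prior-loss branch is the one actually taken):

import torch
import torch.nn.functional as F

def polarity_balanced_loss(pred_cond: torch.Tensor,
                           pred_uncond: torch.Tensor,
                           noise: torch.Tensor,
                           prior_prediction_loss: torch.Tensor,
                           differential_scaler: torch.Tensor) -> torch.Tensor:
    # element-wise diffusion losses for each half of the doubled batch
    cond_loss = F.mse_loss(pred_cond.float(), noise.float(), reduction="none")
    uncond_loss = F.mse_loss(pred_uncond.float(), noise.float(), reduction="none")

    # penalise only the deviation from the frozen model's own error, weighted
    # most heavily where the conditional/unconditional latents differ
    positive = (cond_loss - prior_prediction_loss.float()).abs() * differential_scaler
    positive = positive.mean(dim=[1, 2, 3])

    # polarity term: keep the conditional and unconditional predictions at a
    # similar error level
    polar = (cond_loss - uncond_loss).abs().mean(dim=[1, 2, 3])

    return positive.mean() + polar.mean()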