Tons of bug fixes and improvements to special training. Fixed slider training.

This commit is contained in:
Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions

View File

@@ -1,5 +1,7 @@
import torch
from typing import Literal
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
@@ -8,13 +10,16 @@ from toolkit.train_tools import get_torch_dtype
GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
DIFFERENTIAL_SCALER = 0.2
# DIFFERENTIAL_SCALER = 0.25
def get_differential_mask(
conditional_latents: torch.Tensor,
unconditional_latents: torch.Tensor,
threshold: float = 0.2
threshold: float = 0.2,
gradient: bool = False,
):
# make a differential mask
differential_mask = torch.abs(conditional_latents - unconditional_latents)
@@ -23,12 +28,27 @@ def get_differential_mask(
differential_scaler = 1.0 / max_differential
differential_mask = differential_mask * differential_scaler
# make everything less than 0.2 be 0.0 and everything else be 1.0
differential_mask = torch.where(
differential_mask < threshold,
torch.zeros_like(differential_mask),
torch.ones_like(differential_mask)
)
if gradient:
# we need to scale it to 0-1
# differential_mask = differential_mask - differential_mask.min()
# differential_mask = differential_mask / differential_mask.max()
# add 0.2 threshold to both sides and clip
differential_mask = value_map(
differential_mask,
differential_mask.min(),
differential_mask.max(),
0 - threshold,
1 + threshold
)
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
else:
# make everything less than 0.2 be 0.0 and everything else be 1.0
differential_mask = torch.where(
differential_mask < threshold,
torch.zeros_like(differential_mask),
torch.ones_like(differential_mask)
)
return differential_mask
@@ -47,7 +67,6 @@ def get_targeted_polarity_loss(
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
with torch.no_grad():
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
@@ -164,7 +183,7 @@ def get_targeted_polarity_loss(
return loss
# This targets only the positive differential
# targeted
def get_targeted_guidance_loss(
noisy_latents: torch.Tensor,
@@ -183,35 +202,71 @@ def get_targeted_guidance_loss(
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
unconditional_diff = (unconditional_latents - conditional_latents)
# apply random offset to both latents
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
offset = offset * 0.1
conditional_latents = conditional_latents + offset
unconditional_latents = unconditional_latents + offset
# get random scale of 0.8 to 1.2
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
scale = scale * 0.4
scale = scale + 0.8
conditional_latents = conditional_latents * scale
unconditional_latents = unconditional_latents * scale
diff_mask = get_differential_mask(
conditional_latents,
unconditional_latents,
threshold=0.1
threshold=0.2,
gradient=True
)
# this is a magic number I spent weeks deducing. It works and I have no idea why.
# unconditional_diff_noise = unconditional_diff * DIFFERENTIAL_SCALER
# standardize input to std of 1
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
#
# # scale the latents to std of 1
# conditional_latents = conditional_latents / combo_std
# unconditional_latents = unconditional_latents / combo_std
inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
diff_noise_scaler = noise_abs_mean / inputs_abs_mean
unconditional_diff_noise = unconditional_diff * diff_noise_scaler
unconditional_diff = (unconditional_latents - conditional_latents)
# get a -0.5 to 0.5 multiplier for the diff noise
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
# noise_multiplier = noise_multiplier - 0.5
noise_multiplier = 1.0
# unconditional_diff_noise = unconditional_diff * noise_multiplier
unconditional_diff_noise = unconditional_diff * noise_multiplier
# scale it to the timestep
unconditional_diff_noise = sd.add_noise(
torch.zeros_like(unconditional_latents),
unconditional_diff_noise,
timesteps
)
unconditional_diff_noise = unconditional_diff_noise * 0.2
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
baseline_noisy_latents = sd.add_noise(
conditional_latents,
unconditional_latents,
noise,
timesteps
).detach()
target_noise = noise + unconditional_diff_noise
noisy_latents = sd.add_noise(
conditional_latents,
# noise + unconditional_diff_noise,
noise,
target_noise,
# noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
@@ -226,15 +281,20 @@ def get_targeted_guidance_loss(
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
).detach().requires_grad_(False)
# determine the error for the baseline prediction
baseline_prediction_error = baseline_prediction - noise
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
masked_noise = noise * diff_mask
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
# masked_noise = noise * diff_mask
# pred_target = unmasked_noise + unconditional_diff_noise
# do our prediction with LoRA active on the scaled guidance latents
@@ -246,30 +306,32 @@ def get_targeted_guidance_loss(
**pred_kwargs # adapter residuals in here
)
prediction = prediction - unmasked_baseline_prediction
# prediction = prediction - baseline_prediction
baseline_loss = torch.nn.functional.mse_loss(
baseline_prediction.float(),
noise.float(),
baselined_prediction = prediction - baseline_prediction
guidance_loss = torch.nn.functional.mse_loss(
baselined_prediction.float(),
# unconditional_diff_noise.float(),
unconditional_diff_noise.float(),
reduction="none"
)
baseline_loss = baseline_loss * (1.0 - diff_mask)
baseline_loss = baseline_loss.mean([1, 2, 3])
guidance_loss = guidance_loss.mean([1, 2, 3])
# loss = torch.nn.functional.l1_loss(
loss = torch.nn.functional.mse_loss(
guidance_loss = guidance_loss.mean()
# do the masked noise prediction
masked_noise_loss = torch.nn.functional.mse_loss(
prediction.float(),
masked_noise.float(),
target_noise.float(),
reduction="none"
)
loss = loss * diff_mask
loss = loss.mean([1, 2, 3])
primary_loss_scaler = 1.0
) * diff_mask
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
masked_noise_loss = masked_noise_loss.mean()
loss = (loss * primary_loss_scaler) + baseline_loss
loss = loss.mean()
loss = guidance_loss + masked_noise_loss
loss.backward()
@@ -280,8 +342,158 @@ def get_targeted_guidance_loss(
return loss
def get_targeted_guidance_loss_WIP(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
) -> torch.Tensor:
    """Work-in-progress targeted guidance loss for slider/LoRA training.

    Builds a "guidance" target from the difference between the batch's
    conditional and unconditional latents, predicts noise once with the LoRA
    network disabled (baseline) and once with it enabled, and combines two
    MSE terms:

    * ``guidance_loss`` — pushes (prediction - baseline) toward the scaled
      unconditional-diff noise everywhere.
    * ``masked_noise_loss`` — pushes the LoRA-active prediction toward the
      combined target noise, weighted by the soft differential mask.

    Calls ``loss.backward()`` itself, then returns a detached loss with
    ``requires_grad_(True)`` so the caller's own ``backward()`` is a no-op
    instead of an error.

    NOTE(review): ``noisy_latents`` is immediately shadowed below and
    ``match_adapter_assist`` is unused — both appear to exist only to match
    the shared guidance-function signature. Assumes latents are 4-D
    (batch, channels, h, w) — the per-sample broadcasts use shape
    ``(B, 1, 1, 1)`` and the loss reduces over dims [1, 2, 3]; TODO confirm.

    Side effects: toggles ``sd.network.is_active``, sets
    ``sd.network.multiplier`` and switches ``sd.unet`` eval/train.
    """
    with torch.no_grad():
        # Perform targeted guidance (working title)
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # apply random offset to both latents (same offset, so their
        # difference is unchanged — augments absolute values only)
        offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
        offset = offset * 0.1
        conditional_latents = conditional_latents + offset
        unconditional_latents = unconditional_latents + offset

        # get random scale of 0.8 to 1.2 (rand in [0,1) * 0.4 + 0.8),
        # applied to both latents so the conditional/unconditional diff
        # scales with them
        scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
        scale = scale * 0.4
        scale = scale + 0.8
        conditional_latents = conditional_latents * scale
        unconditional_latents = unconditional_latents * scale

        # soft (gradient=True) 0-1 mask of where the two latents differ
        diff_mask = get_differential_mask(
            conditional_latents,
            unconditional_latents,
            threshold=0.2,
            gradient=True
        )

        # standardize input to std of 1
        # combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
        #
        # # scale the latents to std of 1
        # conditional_latents = conditional_latents / combo_std
        # unconditional_latents = unconditional_latents / combo_std

        unconditional_diff = (unconditional_latents - conditional_latents)

        # get a -0.5 to 0.5 multiplier for the diff noise
        # noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
        # noise_multiplier = noise_multiplier - 0.5
        noise_multiplier = 1.0  # random multiplier disabled; kept for experimentation
        # unconditional_diff_noise = unconditional_diff * noise_multiplier
        unconditional_diff_noise = unconditional_diff * noise_multiplier

        # scale it to the timestep: adding "noise" to a zero latent yields the
        # diff scaled by the scheduler's noise coefficient for each timestep
        unconditional_diff_noise = sd.add_noise(
            torch.zeros_like(unconditional_latents),
            unconditional_diff_noise,
            timesteps
        )
        # 0.2 damping factor — empirically chosen, see DIFFERENTIAL_SCALER above
        unconditional_diff_noise = unconditional_diff_noise * 0.2
        unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)

        # baseline input: the UNCONDITIONAL latents with plain noise
        baseline_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # training target: plain noise plus the scaled unconditional diff
        target_noise = noise + unconditional_diff_noise

        # LoRA-pass input: the CONDITIONAL latents noised with target_noise
        # (shadows the noisy_latents parameter on purpose)
        noisy_latents = sd.add_noise(
            conditional_latents,
            target_noise,
            # noise,
            timesteps
        ).detach()

        # Disable the LoRA network so we can predict parent network knowledge without it
        sd.network.is_active = False
        sd.unet.eval()

        # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
        # This acts as our control to preserve the unaltered parts of the image.
        baseline_prediction = sd.predict_noise(
            latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach().requires_grad_(False)

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list

    # unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
    # masked_noise = noise * diff_mask
    # pred_target = unmasked_noise + unconditional_diff_noise

    # do our prediction with LoRA active on the scaled guidance latents
    # (outside no_grad — this builds the graph that loss.backward() uses)
    prediction = sd.predict_noise(
        latents=noisy_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    # what the LoRA added on top of the parent network's baseline
    baselined_prediction = prediction - baseline_prediction

    # push the LoRA's contribution toward the scaled unconditional diff noise
    guidance_loss = torch.nn.functional.mse_loss(
        baselined_prediction.float(),
        # unconditional_diff_noise.float(),
        unconditional_diff_noise.float(),
        reduction="none"
    )
    guidance_loss = guidance_loss.mean([1, 2, 3])
    guidance_loss = guidance_loss.mean()

    # do the masked noise prediction: standard noise-prediction MSE,
    # weighted by the soft differential mask before reduction
    masked_noise_loss = torch.nn.functional.mse_loss(
        prediction.float(),
        target_noise.float(),
        reduction="none"
    ) * diff_mask
    masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
    masked_noise_loss = masked_noise_loss.mean()

    loss = guidance_loss + masked_noise_loss

    loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
def get_guided_loss_polarity(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,