Tons of bug fixes and improvements to special training. Fixed slider training.

This commit is contained in:
Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions

View File

@@ -2,13 +2,15 @@ from collections import OrderedDict
from typing import Union, Literal, List
from diffusers import T2IAdapter
import torch.nn.functional as F
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
apply_learnable_snr_gos, LearnableSNRGamma
@@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
@@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess):
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
loss = get_guidance_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
sd=self.sd,
**kwargs
)
return loss
def get_guided_loss_targeted_polarity(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
@@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess):
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
unconditional_diff = unconditional_latents - conditional_latents
conditional_diff = conditional_latents - unconditional_latents
mean_latents = (conditional_latents + unconditional_latents) / 2.0
unconditional_diff = (unconditional_latents - mean_latents)
conditional_diff = (conditional_latents - mean_latents)
# we need to determine the amount of signal and noise that would be present at the current timestep
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
target_noise = noise + unconditional_signal
# target_noise = noise + unconditional_signal
conditional_noisy_latents = self.sd.add_noise(
unconditional_latents + conditional_signal,
target_noise,
mean_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = self.sd.add_noise(
unconditional_latents,
mean_latents,
noise,
timesteps
).detach()
@@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs # adapter residuals in here
).detach()
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = network_weight_list
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
# since we are dividing the polarity from the middle out, we need to double our network
# weights during training because the convergent point will be at half network strength
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
cat_network_weight_list = positive_network_weights + negative_network_weights
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = cat_network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=cat_timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# remove baseline from our prediction to extract our differential prediction
prediction = prediction - baseline_prediction
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
loss = torch.nn.functional.mse_loss(
prediction.float(),
unconditional_signal.float(),
pred_pos = pred_pos - baseline_prediction
pred_neg = pred_neg - baseline_prediction
pred_loss = torch.nn.functional.mse_loss(
pred_pos.float(),
unconditional_diff.float(),
reduction="none"
)
loss = loss.mean([1, 2, 3])
pred_loss = pred_loss.mean([1, 2, 3])
loss = self.apply_snr(loss, timesteps)
pred_neg_loss = torch.nn.functional.mse_loss(
pred_neg.float(),
conditional_diff.float(),
reduction="none"
)
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
loss = (pred_loss + pred_neg_loss) / 2.0
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
# detach it so the parent class can run backward on no grads without throwing an error
loss = loss.detach()
loss.requires_grad_(True)
return loss
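As a standalone illustration of the mean-centered polarity targets used above (a minimal sketch; shapes and helper names are assumptions, not the trainer's API):
import torch
# converge on the midpoint between the paired latents; the two diffs
# are exact mirrors of each other around that mean
def polarity_targets(conditional_latents: torch.Tensor, unconditional_latents: torch.Tensor):
    mean_latents = (conditional_latents + unconditional_latents) / 2.0
    unconditional_diff = unconditional_latents - mean_latents
    conditional_diff = conditional_latents - mean_latents
    return mean_latents, unconditional_diff, conditional_diff
cond, uncond = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
mean_latents, u_diff, c_diff = polarity_targets(cond, uncond)
assert torch.allclose(u_diff, -c_diff)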
def get_guided_loss_masked_polarity(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
dtype = get_torch_dtype(self.train_config.dtype)
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
mean_latents = (conditional_latents + unconditional_latents) / 2.0
# unconditional_diff = (unconditional_latents - mean_latents)
# conditional_diff = (conditional_latents - mean_latents)
# we need to determine the amount of signal and noise that would be present at the current timestep
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
# make a differential mask
differential_mask = torch.abs(conditional_latents - unconditional_latents)
max_differential = \
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
differential_scaler = 1.0 / max_differential
differential_mask = differential_mask * differential_scaler
spread_point = 0.1
# adjust mask to amplify the differential at 0.1
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
# clip it
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
# target_noise = noise + unconditional_signal
conditional_noisy_latents = self.sd.add_noise(
conditional_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = self.sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
inverse_noisy_latents = self.sd.add_noise(
inverse_latents,
noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
self.network.is_active = False
self.sd.unet.eval()
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
# baseline_prediction = self.sd.predict_noise(
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
# timestep=timesteps,
# guidance_scale=1.0,
# **pred_kwargs # adapter residuals in here
# ).detach()
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
# since we are dividing the polarity from the middle out, we need to double our network
# weights during training because the convergent point will be at half network strength
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
cat_network_weight_list = positive_network_weights + negative_network_weights
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = cat_network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=cat_timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
# create a loss to balance the mean to 0 between the two predictions
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
# pred_pos = pred_pos - baseline_prediction
# pred_neg = pred_neg - baseline_prediction
pred_loss = torch.nn.functional.mse_loss(
pred_pos.float(),
noise.float(),
reduction="none"
)
# apply mask
pred_loss = pred_loss * (1.0 + differential_mask)
pred_loss = pred_loss.mean([1, 2, 3])
pred_neg_loss = torch.nn.functional.mse_loss(
pred_neg.float(),
noise.float(),
reduction="none"
)
# apply inverse mask
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
# make a loss to balance the losses of the pos and neg so they are equal
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
#
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
#
# # add a multiplier to balancing losses to make them the top priority
# differential_mean_loss = differential_mean_loss
# remove the grads from the negative as it is only a balancing loss
# pred_neg_loss = pred_neg_loss.detach()
# loss = pred_loss + pred_neg_loss + differential_mean_loss
loss = pred_loss + pred_neg_loss
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
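The differential mask above can be reproduced in isolation; this sketch uses torch.amax in place of the chained .max calls, which is equivalent (shapes are assumed, not taken from the trainer):
import torch
def amplified_differential_mask(cond: torch.Tensor, uncond: torch.Tensor, spread_point: float = 0.1) -> torch.Tensor:
    mask = torch.abs(cond - uncond)
    # normalize each sample to 0..1 by its own max differential
    mask = mask / mask.amax(dim=(1, 2, 3), keepdim=True)
    # amplify around the spread point by 10x, then clip
    mask = (mask - spread_point) * 10.0 + spread_point
    return torch.clamp(mask, 0.0, 1.0)
mask = amplified_differential_mask(torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))
print(mask.min().item(), mask.max().item())  # most values saturate to 0 or 1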
@@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess):
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None:
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,

View File

@@ -293,6 +293,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# will end in safetensors or pt
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
# check for critic files
critic_pattern = f"CRITIC_{self.job.name}_*"
critic_items = glob.glob(os.path.join(self.save_root, critic_pattern))
# Sort the lists by creation time if they are not empty
if safetensors_files:
safetensors_files.sort(key=os.path.getctime)
@@ -302,6 +306,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
directories.sort(key=os.path.getctime)
if embed_files:
embed_files.sort(key=os.path.getctime)
if critic_items:
critic_items.sort(key=os.path.getctime)
# Combine and sort the lists
combined_items = safetensors_files + directories + pt_files
@@ -313,8 +319,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
# remove all but the latest max_step_saves_to_keep
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
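The new CRITIC_* checkpoints follow the same keep-the-newest pruning as the other artifacts; the pattern in isolation (save_root, pattern, and keep are placeholders):
import glob
import os
def prune_old_saves(save_root: str, pattern: str, keep: int) -> None:
    items = glob.glob(os.path.join(save_root, pattern))
    items.sort(key=os.path.getctime)  # oldest first
    for path in items[:-keep]:  # assumes keep >= 1
        os.remove(path)
# e.g. prune_old_saves('/output/my_job', 'CRITIC_my_job_*', keep=4)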
@@ -1041,8 +1048,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_ssd=self.model_config.is_ssd,
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,

View File

@@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
1, self.train_config.max_denoising_steps - 1, (1,)
).item()
# get noise
@@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
assert not self.network.is_active
self.sd.unet.eval()
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
# double up since we are doing cfg
self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
@@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
):
self.network.multiplier = anchor_chunk.multiplier_list
self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
anchor_pred_noise = get_noise_pred(
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
@@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
mask_multiplier_chunks,
unmasked_target_chunks
):
self.network.multiplier = prompt_pair_chunk.multiplier_list
self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
target_latents = get_noise_pred(
prompt_pair_chunk.positive_target,
prompt_pair_chunk.target_class,
@@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
offset_neutral = neutral_latents_chunk
# offsets are already adjusted on a per-batch basis
offset_neutral += offset
offset_neutral = offset_neutral.detach().requires_grad_(False)
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
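The doubled multiplier lists above track the CFG batch layout: unconditional and conditional inputs are concatenated along the batch dimension, so a per-sample multiplier list must be repeated to match. A toy check of the shape bookkeeping (the real network applies these inside its modules):
import torch
multiplier_list = [1.0, -1.0]                       # one entry per sample
latents = torch.randn(2, 4, 64, 64)
cfg_latents = torch.cat([latents, latents], dim=0)  # CFG doubles the batch
cfg_multipliers = multiplier_list + multiplier_list
assert cfg_latents.shape[0] == len(cfg_multipliers)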

View File

@@ -0,0 +1,20 @@
import argparse
import torch
import os
from diffusers import StableDiffusionPipeline
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# add project root to path
sys.path.append(PROJECT_ROOT)
SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales')
parser = argparse.ArgumentParser(description='Compute sampler scales for a model.')
add_arg = parser.add_argument
add_arg('--model', type=str, required=True, help='Path to model')
add_arg('--sampler', type=str, required=True, help='Name of sampler')
args = parser.parse_args()

View File

@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union
from typing import List, Optional, Literal, Union, TYPE_CHECKING
import random
import torch
@@ -11,6 +11,8 @@ ImgExt = Literal['jpg', 'png', 'webp']
SaveFormat = Literal['safetensors', 'diffusers']
if TYPE_CHECKING:
from toolkit.guidance import GuidanceType
class SaveConfig:
def __init__(self, **kwargs):
@@ -400,6 +402,7 @@ class DatasetConfig:
if legacy_caption_type:
self.caption_ext = legacy_caption_type
self.caption_type = self.caption_ext
self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted_polarity')
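A minimal config sketch for the new option (folder_path is a hypothetical kwarg shown for illustration; the valid values come from toolkit.guidance.GuidanceType):
dataset_config = DatasetConfig(
    folder_path='/path/to/dataset',  # hypothetical kwarg for illustration
    guidance_type='polarity',        # 'targeted' | 'polarity' | 'targeted_polarity'
)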
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -1,5 +1,7 @@
import torch
from typing import Literal
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
@@ -8,13 +10,16 @@ from toolkit.train_tools import get_torch_dtype
GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
DIFFERENTIAL_SCALER = 0.2
# DIFFERENTIAL_SCALER = 0.25
def get_differential_mask(
conditional_latents: torch.Tensor,
unconditional_latents: torch.Tensor,
threshold: float = 0.2
threshold: float = 0.2,
gradient: bool = False,
):
# make a differential mask
differential_mask = torch.abs(conditional_latents - unconditional_latents)
@@ -23,12 +28,27 @@ def get_differential_mask(
differential_scaler = 1.0 / max_differential
differential_mask = differential_mask * differential_scaler
# make everything less than 0.2 be 0.0 and everything else be 1.0
differential_mask = torch.where(
differential_mask < threshold,
torch.zeros_like(differential_mask),
torch.ones_like(differential_mask)
)
if gradient:
# we need to scale it to 0-1
# differential_mask = differential_mask - differential_mask.min()
# differential_mask = differential_mask / differential_mask.max()
# add 0.2 threshold to both sides and clip
differential_mask = value_map(
differential_mask,
differential_mask.min(),
differential_mask.max(),
0 - threshold,
1 + threshold
)
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
else:
# make everything less than 0.2 be 0.0 and everything else be 1.0
differential_mask = torch.where(
differential_mask < threshold,
torch.zeros_like(differential_mask),
torch.ones_like(differential_mask)
)
return differential_mask
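value_map is assumed here to be a linear range remap; if so, stretching by the threshold on both ends and then clamping yields a soft mask that is exactly 0 below the threshold and exactly 1 near the top (this reimplementation is an assumption, not the toolkit's source):
import torch
def value_map(x: torch.Tensor, in_min, in_max, out_min, out_max) -> torch.Tensor:
    # linear remap from [in_min, in_max] to [out_min, out_max]
    return (x - in_min) / (in_max - in_min) * (out_max - out_min) + out_min
m = torch.linspace(0.0, 1.0, 11)
m = torch.clamp(value_map(m, m.min(), m.max(), 0.0 - 0.2, 1.0 + 0.2), 0.0, 1.0)
print(m)  # 0.0 at the low end, 1.0 at the high end, linear in between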
@@ -47,7 +67,6 @@ def get_targeted_polarity_loss(
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
with torch.no_grad():
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
@@ -164,7 +183,7 @@ def get_targeted_polarity_loss(
return loss
# This targets only the positive differential
# targeted
def get_targeted_guidance_loss(
noisy_latents: torch.Tensor,
@@ -183,35 +202,71 @@ def get_targeted_guidance_loss(
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
unconditional_diff = (unconditional_latents - conditional_latents)
# apply random offset to both latents
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
offset = offset * 0.1
conditional_latents = conditional_latents + offset
unconditional_latents = unconditional_latents + offset
# get a random scale of 0.8 to 1.2
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
scale = scale * 0.4
scale = scale + 0.8
conditional_latents = conditional_latents * scale
unconditional_latents = unconditional_latents * scale
diff_mask = get_differential_mask(
conditional_latents,
unconditional_latents,
threshold=0.1
threshold=0.2,
gradient=True
)
# this is a magic number I spent weeks deducing. It works and I have no idea why.
# unconditional_diff_noise = unconditional_diff * DIFFERENTIAL_SCALER
# standardize inputs to std of 1
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
#
# # scale the latents to std of 1
# conditional_latents = conditional_latents / combo_std
# unconditional_latents = unconditional_latents / combo_std
inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
diff_noise_scaler = noise_abs_mean / inputs_abs_mean
unconditional_diff_noise = unconditional_diff * diff_noise_scaler
unconditional_diff = (unconditional_latents - conditional_latents)
# get a -0.5 to 0.5 multiplier for the diff noise
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
# noise_multiplier = noise_multiplier - 0.5
noise_multiplier = 1.0
# unconditional_diff_noise = unconditional_diff * noise_multiplier
unconditional_diff_noise = unconditional_diff * noise_multiplier
# scale it to the timestep
unconditional_diff_noise = sd.add_noise(
torch.zeros_like(unconditional_latents),
unconditional_diff_noise,
timesteps
)
unconditional_diff_noise = unconditional_diff_noise * 0.2
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
baseline_noisy_latents = sd.add_noise(
conditional_latents,
unconditional_latents,
noise,
timesteps
).detach()
target_noise = noise + unconditional_diff_noise
noisy_latents = sd.add_noise(
conditional_latents,
# noise + unconditional_diff_noise,
noise,
target_noise,
# noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
@@ -226,15 +281,20 @@ def get_targeted_guidance_loss(
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
).detach().requires_grad_(False)
# determine the error for the baseline prediction
baseline_prediction_error = baseline_prediction - noise
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
masked_noise = noise * diff_mask
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
# masked_noise = noise * diff_mask
# pred_target = unmasked_noise + unconditional_diff_noise
# do our prediction with LoRA active on the scaled guidance latents
@@ -246,30 +306,32 @@ def get_targeted_guidance_loss(
**pred_kwargs # adapter residuals in here
)
prediction = prediction - unmasked_baseline_prediction
# prediction = prediction - baseline_prediction
baseline_loss = torch.nn.functional.mse_loss(
baseline_prediction.float(),
noise.float(),
baselined_prediction = prediction - baseline_prediction
guidance_loss = torch.nn.functional.mse_loss(
baselined_prediction.float(),
# unconditional_diff_noise.float(),
unconditional_diff_noise.float(),
reduction="none"
)
baseline_loss = baseline_loss * (1.0 - diff_mask)
baseline_loss = baseline_loss.mean([1, 2, 3])
guidance_loss = guidance_loss.mean([1, 2, 3])
# loss = torch.nn.functional.l1_loss(
loss = torch.nn.functional.mse_loss(
guidance_loss = guidance_loss.mean()
# do the masked noise prediction
masked_noise_loss = torch.nn.functional.mse_loss(
prediction.float(),
masked_noise.float(),
target_noise.float(),
reduction="none"
)
loss = loss * diff_mask
loss = loss.mean([1, 2, 3])
primary_loss_scaler = 1.0
) * diff_mask
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
masked_noise_loss = masked_noise_loss.mean()
loss = (loss * primary_loss_scaler) + baseline_loss
loss = loss.mean()
loss = guidance_loss + masked_noise_loss
loss.backward()
@@ -280,8 +342,158 @@ def get_targeted_guidance_loss(
return loss
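The sd.add_noise(zeros, diff, timesteps) call above rides on the scheduler's forward process: with a zero sample, only the timestep-scaled noise term survives. Under the usual DDPM formulation (an assumption; the actual scaling depends on the configured scheduler):
import torch
# x_t = sqrt(ac[t]) * sample + sqrt(1 - ac[t]) * noise
def add_noise(sample, noise, timesteps, alphas_cumprod):
    ac = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    return ac.sqrt() * sample + (1.0 - ac).sqrt() * noise
alphas_cumprod = torch.linspace(0.999, 0.001, 1000)
diff = torch.randn(1, 4, 8, 8)
scaled_diff = add_noise(torch.zeros_like(diff), diff, torch.tensor([500]), alphas_cumprod)
# scaled_diff is the differential attenuated exactly like noise at t=500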
def get_targeted_guidance_loss_WIP(
noisy_latents: torch.Tensor,
conditional_embeds: 'PromptEmbeds',
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
# apply random offset to both latents
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
offset = offset * 0.1
conditional_latents = conditional_latents + offset
unconditional_latents = unconditional_latents + offset
# get a random scale of 0.8 to 1.2
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
scale = scale * 0.4
scale = scale + 0.8
conditional_latents = conditional_latents * scale
unconditional_latents = unconditional_latents * scale
diff_mask = get_differential_mask(
conditional_latents,
unconditional_latents,
threshold=0.2,
gradient=True
)
# standardize inputs to std of 1
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
#
# # scale the latents to std of 1
# conditional_latents = conditional_latents / combo_std
# unconditional_latents = unconditional_latents / combo_std
unconditional_diff = (unconditional_latents - conditional_latents)
# get a -0.5 to 0.5 multiplier for the diff noise
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
# noise_multiplier = noise_multiplier - 0.5
noise_multiplier = 1.0
# unconditional_diff_noise = unconditional_diff * noise_multiplier
unconditional_diff_noise = unconditional_diff * noise_multiplier
# scale it to the timestep
unconditional_diff_noise = sd.add_noise(
torch.zeros_like(unconditional_latents),
unconditional_diff_noise,
timesteps
)
unconditional_diff_noise = unconditional_diff_noise * 0.2
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
baseline_noisy_latents = sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
target_noise = noise + unconditional_diff_noise
noisy_latents = sd.add_noise(
conditional_latents,
target_noise,
# noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
sd.network.is_active = False
sd.unet.eval()
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
baseline_prediction = sd.predict_noise(
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach().requires_grad_(False)
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
# masked_noise = noise * diff_mask
# pred_target = unmasked_noise + unconditional_diff_noise
# do our prediction with LoRA active on the scaled guidance latents
prediction = sd.predict_noise(
latents=noisy_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
baselined_prediction = prediction - baseline_prediction
guidance_loss = torch.nn.functional.mse_loss(
baselined_prediction.float(),
# unconditional_diff_noise.float(),
unconditional_diff_noise.float(),
reduction="none"
)
guidance_loss = guidance_loss.mean([1, 2, 3])
guidance_loss = guidance_loss.mean()
# do the masked noise prediction
masked_noise_loss = torch.nn.functional.mse_loss(
prediction.float(),
target_noise.float(),
reduction="none"
) * diff_mask
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
masked_noise_loss = masked_noise_loss.mean()
loss = guidance_loss + masked_noise_loss
loss.backward()
# detach it so the parent class can run backward on no grads without throwing an error
loss = loss.detach()
loss.requires_grad_(True)
return loss
def get_guided_loss_polarity(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,

View File

@@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
from toolkit.saving import get_lora_keymap_from_model_keymap
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
@@ -338,6 +339,7 @@ class ToolkitNetworkMixin:
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
is_ssd=False,
network_config: Optional[NetworkConfig] = None,
is_lorm=False,
**kwargs
@@ -348,6 +350,7 @@ class ToolkitNetworkMixin:
self._multiplier: float = 1.0
self.is_active: bool = False
self.is_sdxl = is_sdxl
self.is_ssd = is_ssd
self.is_v2 = is_v2
self.is_merged_in = False
self.is_lorm = is_lorm
@@ -357,14 +360,25 @@ class ToolkitNetworkMixin:
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
if self.is_sdxl:
use_weight_mapping = False
if self.is_ssd:
keymap_tail = 'ssd'
use_weight_mapping = True
elif self.is_sdxl:
keymap_tail = 'sdxl'
elif self.is_v2:
keymap_tail = 'sd2'
else:
keymap_tail = 'sd1'
# todo double check this
use_weight_mapping = True
# load keymap
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
if use_weight_mapping:
keymap_name = f"stable_diffusion_{keymap_tail}.json"
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
keymap = None
@@ -373,6 +387,10 @@ class ToolkitNetworkMixin:
with open(keymap_path, 'r') as f:
keymap = json.load(f)['ldm_diffusers_keymap']
if use_weight_mapping and keymap is not None:
# get keymap from weights
keymap = get_lora_keymap_from_model_keymap(keymap)
return keymap
def save_weights(

View File

@@ -206,6 +206,7 @@ def load_t2i_model(
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
def save_ip_adapter_from_diffusers(
combined_state_dict: 'OrderedDict',
output_file: str,
@@ -241,3 +242,58 @@ def load_ip_adapter_model(
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
lora_keymap = OrderedDict()
# see if we have dual text encoders (a key that starts with conditioner.embedders.1)
has_dual_text_encoders = False
for key in model_keymap:
if key.startswith('conditioner.embedders.1'):
has_dual_text_encoders = True
break
# map through the keys and values
for key, value in model_keymap.items():
# ignore bias weights
if key.endswith('bias'):
continue
if key.endswith('.weight'):
# remove the .weight
key = key[:-7]
if value.endswith(".weight"):
# remove the .weight
value = value[:-7]
# unet for all
key = key.replace('model.diffusion_model', 'lora_unet')
if value.startswith('unet'):
value = f"lora_{value}"
# text encoder
if has_dual_text_encoders:
key = key.replace('conditioner.embedders.0', 'lora_te1')
key = key.replace('conditioner.embedders.1', 'lora_te2')
if value.startswith('te0') or value.startswith('te1'):
value = f"lora_{value}"
value = value.replace('lora_te1', 'lora_te2')
value = value.replace('lora_te0', 'lora_te1')
key = key.replace('cond_stage_model.transformer', 'lora_te')
if value.startswith('te_'):
value = f"lora_{value}"
# replace periods with underscores
key = key.replace('.', '_')
value = value.replace('.', '_')
# add all the weights
lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight"
lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias"
lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight"
lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias"
lora_keymap[f"{key}.alpha"] = f"{value}.alpha"
return lora_keymap
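A quick illustration of the transform on a single made-up key pair (the real keymaps are model-specific JSON files):
from collections import OrderedDict
model_keymap = OrderedDict({
    'model.diffusion_model.input_blocks.1.0.in_layers.2.weight':
        'unet.down_blocks.0.resnets.0.conv1.weight',
})
lora_keymap = get_lora_keymap_from_model_keymap(model_keymap)
# each weight expands into lora_down/lora_up weight+bias plus alpha, e.g.
# 'lora_unet_input_blocks_1_0_in_layers_2.lora_down.weight' ->
#     'lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight'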

View File

@@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline
StableDiffusionXLImg2ImgPipeline, LCMScheduler
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -279,6 +279,20 @@ class StableDiffusion:
self.load_refiner()
self.is_loaded = True
def te_train(self):
if isinstance(self.text_encoder, list):
for te in self.text_encoder:
te.train()
else:
self.text_encoder.train()
def te_eval(self):
if isinstance(self.text_encoder, list):
for te in self.text_encoder:
te.eval()
else:
self.text_encoder.eval()
def load_refiner(self):
# for now, we are just going to rely on the TE from the base model
# which is TE2 for SDXL and TE for SD (no refiner currently)
@@ -721,6 +735,7 @@ class StableDiffusion:
add_time_ids=None,
conditional_embeddings: Union[PromptEmbeds, None] = None,
unconditional_embeddings: Union[PromptEmbeds, None] = None,
is_input_scaled=True,
**kwargs,
):
with torch.no_grad():
@@ -764,6 +779,8 @@ class StableDiffusion:
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
@@ -859,7 +876,7 @@ class StableDiffusion:
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
@@ -903,7 +920,7 @@ class StableDiffusion:
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
**kwargs,
@@ -924,6 +941,15 @@ class StableDiffusion:
return noise_pred
def step_scheduler(self, model_input, latent_input, timestep_tensor):
# sometimes they are on the wrong device, no idea why
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
try:
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
except Exception as e:
pass
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
@@ -955,10 +981,12 @@ class StableDiffusion:
add_time_ids=None,
bleed_ratio: float = 0.5,
bleed_latents: torch.FloatTensor = None,
is_input_scaled=False,
**kwargs,
):
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
for timestep in tqdm(timesteps_to_run, leave=False):
timestep = timestep.unsqueeze_(0)
noise_pred = self.predict_noise(
@@ -967,6 +995,7 @@ class StableDiffusion:
timestep,
guidance_scale=guidance_scale,
add_time_ids=add_time_ids,
is_input_scaled=is_input_scaled,
**kwargs,
)
# some schedulers need to run separately, so do that. (euler for example)
@@ -977,6 +1006,9 @@ class StableDiffusion:
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
# only skip first scaling
is_input_scaled = False
# return latents_steps
return latents
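The is_input_scaled flag lets a caller hand in latents that were already scaled for the first step; resetting it at the bottom of the loop means only that first step can skip scaling. The pattern in isolation (toy stand-ins for the scheduler calls):
def scale_model_input(x, t):  # stand-in for the scheduler's input scaling
    return x * 0.5
def run_steps(x, timesteps, is_input_scaled=False):
    for t in timesteps:
        if not is_input_scaled:
            x = scale_model_input(x, t)
        is_input_scaled = False  # only the first step may skip scaling
    return x
print(run_steps(1.0, range(3)))                        # scaled on all 3 steps
print(run_steps(1.0, range(3), is_input_scaled=True))  # first step skipped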