mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -2,13 +2,15 @@ from collections import OrderedDict
|
|||||||
from typing import Union, Literal, List
|
from typing import Union, Literal, List
|
||||||
from diffusers import T2IAdapter
|
from diffusers import T2IAdapter
|
||||||
|
|
||||||
|
import torch.functional as F
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.basic import value_map, adain, get_mean_std
|
from toolkit.basic import value_map, adain, get_mean_std
|
||||||
from toolkit.config_modules import GuidanceConfig
|
from toolkit.config_modules import GuidanceConfig
|
||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||||
|
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||||
from toolkit.image_utils import show_tensors, show_latents
|
from toolkit.image_utils import show_tensors, show_latents
|
||||||
from toolkit.ip_adapter import IPAdapter
|
from toolkit.ip_adapter import IPAdapter
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||||
apply_learnable_snr_gos, LearnableSNRGamma
|
apply_learnable_snr_gos, LearnableSNRGamma
|
||||||
@@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.assistant_adapter: Union['T2IAdapter', None]
|
self.assistant_adapter: Union['T2IAdapter', None]
|
||||||
self.do_prior_prediction = False
|
self.do_prior_prediction = False
|
||||||
self.do_long_prompts = False
|
self.do_long_prompts = False
|
||||||
|
self.do_guided_loss = False
|
||||||
if self.train_config.inverted_mask_prior:
|
if self.train_config.inverted_mask_prior:
|
||||||
self.do_prior_prediction = True
|
self.do_prior_prediction = True
|
||||||
|
|
||||||
@@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
batch: 'DataLoaderBatchDTO',
|
batch: 'DataLoaderBatchDTO',
|
||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
**kwargs
|
**kwargs
|
||||||
|
):
|
||||||
|
loss = get_guidance_loss(
|
||||||
|
noisy_latents=noisy_latents,
|
||||||
|
conditional_embeds=conditional_embeds,
|
||||||
|
match_adapter_assist=match_adapter_assist,
|
||||||
|
network_weight_list=network_weight_list,
|
||||||
|
timesteps=timesteps,
|
||||||
|
pred_kwargs=pred_kwargs,
|
||||||
|
batch=batch,
|
||||||
|
noise=noise,
|
||||||
|
sd=self.sd,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def get_guided_loss_targeted_polarity(
|
||||||
|
self,
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
match_adapter_assist: bool,
|
||||||
|
network_weight_list: list,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
pred_kwargs: dict,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
noise: torch.Tensor,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Perform targeted guidance (working title)
|
# Perform targeted guidance (working title)
|
||||||
@@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||||
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
|
||||||
unconditional_diff = unconditional_latents - conditional_latents
|
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||||
conditional_diff = conditional_latents - unconditional_latents
|
|
||||||
|
unconditional_diff = (unconditional_latents - mean_latents)
|
||||||
|
conditional_diff = (conditional_latents - mean_latents)
|
||||||
|
|
||||||
# we need to determine the amount of signal and noise that would be present at the current timestep
|
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||||
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||||
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||||
|
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||||
|
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||||
|
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||||
|
|
||||||
target_noise = noise + unconditional_signal
|
# target_noise = noise + unconditional_signal
|
||||||
|
|
||||||
conditional_noisy_latents = self.sd.add_noise(
|
conditional_noisy_latents = self.sd.add_noise(
|
||||||
unconditional_latents + conditional_signal,
|
mean_latents,
|
||||||
target_noise,
|
noise,
|
||||||
timesteps
|
timesteps
|
||||||
).detach()
|
).detach()
|
||||||
|
|
||||||
unconditional_noisy_latents = self.sd.add_noise(
|
unconditional_noisy_latents = self.sd.add_noise(
|
||||||
unconditional_latents,
|
mean_latents,
|
||||||
noise,
|
noise,
|
||||||
timesteps
|
timesteps
|
||||||
).detach()
|
).detach()
|
||||||
@@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
).detach()
|
).detach()
|
||||||
|
|
||||||
# turn the LoRA network back on.
|
# double up everything to run it through all at once
|
||||||
self.sd.unet.train()
|
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||||
self.network.is_active = True
|
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
|
||||||
self.network.multiplier = network_weight_list
|
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||||
|
|
||||||
|
# since we are dividing the polarity from the middle out, we need to double our network
|
||||||
|
# weights on training since the convergent point will be at half network strength
|
||||||
|
|
||||||
|
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
|
||||||
|
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
|
||||||
|
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||||
|
|
||||||
|
# turn the LoRA network back on.
|
||||||
|
self.sd.unet.train()
|
||||||
|
self.network.is_active = True
|
||||||
|
|
||||||
|
self.network.multiplier = cat_network_weight_list
|
||||||
|
|
||||||
# do our prediction with LoRA active on the scaled guidance latents
|
# do our prediction with LoRA active on the scaled guidance latents
|
||||||
prediction = self.sd.predict_noise(
|
prediction = self.sd.predict_noise(
|
||||||
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
timestep=timesteps,
|
timestep=cat_timesteps,
|
||||||
guidance_scale=1.0,
|
guidance_scale=1.0,
|
||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
)
|
)
|
||||||
|
|
||||||
# remove baseline from our prediction to extract our differential prediction
|
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||||
prediction = prediction - baseline_prediction
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(
|
pred_pos = pred_pos - baseline_prediction
|
||||||
prediction.float(),
|
pred_neg = pred_neg - baseline_prediction
|
||||||
unconditional_signal.float(),
|
|
||||||
|
pred_loss = torch.nn.functional.mse_loss(
|
||||||
|
pred_pos.float(),
|
||||||
|
unconditional_diff.float(),
|
||||||
reduction="none"
|
reduction="none"
|
||||||
)
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
pred_loss = pred_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss = self.apply_snr(loss, timesteps)
|
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||||
|
pred_neg.float(),
|
||||||
|
conditional_diff.float(),
|
||||||
|
reduction="none"
|
||||||
|
)
|
||||||
|
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
loss = (pred_loss + pred_neg_loss) / 2.0
|
||||||
|
|
||||||
|
# loss = self.apply_snr(loss, timesteps)
|
||||||
|
loss = loss.mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# detach it so parent class can run backward on no grads without throwing error
|
||||||
|
loss = loss.detach()
|
||||||
|
loss.requires_grad_(True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def get_guided_loss_masked_polarity(
|
||||||
|
self,
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
match_adapter_assist: bool,
|
||||||
|
network_weight_list: list,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
pred_kwargs: dict,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
noise: torch.Tensor,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
# Perform targeted guidance (working title)
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
|
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
|
||||||
|
|
||||||
|
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||||
|
|
||||||
|
# unconditional_diff = (unconditional_latents - mean_latents)
|
||||||
|
# conditional_diff = (conditional_latents - mean_latents)
|
||||||
|
|
||||||
|
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||||
|
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||||
|
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||||
|
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||||
|
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||||
|
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||||
|
|
||||||
|
# make a differential mask
|
||||||
|
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||||
|
max_differential = \
|
||||||
|
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
||||||
|
differential_scaler = 1.0 / max_differential
|
||||||
|
differential_mask = differential_mask * differential_scaler
|
||||||
|
spread_point = 0.1
|
||||||
|
# adjust mask to amplify the differential at 0.1
|
||||||
|
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
|
||||||
|
# clip it
|
||||||
|
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
|
||||||
|
|
||||||
|
# target_noise = noise + unconditional_signal
|
||||||
|
|
||||||
|
conditional_noisy_latents = self.sd.add_noise(
|
||||||
|
conditional_latents,
|
||||||
|
noise,
|
||||||
|
timesteps
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
unconditional_noisy_latents = self.sd.add_noise(
|
||||||
|
unconditional_latents,
|
||||||
|
noise,
|
||||||
|
timesteps
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
inverse_noisy_latents = self.sd.add_noise(
|
||||||
|
inverse_latents,
|
||||||
|
noise,
|
||||||
|
timesteps
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||||
|
self.network.is_active = False
|
||||||
|
self.sd.unet.eval()
|
||||||
|
|
||||||
|
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||||
|
# This acts as our control to preserve the unaltered parts of the image.
|
||||||
|
# baseline_prediction = self.sd.predict_noise(
|
||||||
|
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
# timestep=timesteps,
|
||||||
|
# guidance_scale=1.0,
|
||||||
|
# **pred_kwargs # adapter residuals in here
|
||||||
|
# ).detach()
|
||||||
|
|
||||||
|
# double up everything to run it through all at once
|
||||||
|
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||||
|
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
|
||||||
|
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||||
|
|
||||||
|
# since we are dividing the polarity from the middle out, we need to double our network
|
||||||
|
# weights on training since the convergent point will be at half network strength
|
||||||
|
|
||||||
|
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
|
||||||
|
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
|
||||||
|
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||||
|
|
||||||
|
# turn the LoRA network back on.
|
||||||
|
self.sd.unet.train()
|
||||||
|
self.network.is_active = True
|
||||||
|
|
||||||
|
self.network.multiplier = cat_network_weight_list
|
||||||
|
|
||||||
|
# do our prediction with LoRA active on the scaled guidance latents
|
||||||
|
prediction = self.sd.predict_noise(
|
||||||
|
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
timestep=cat_timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||||
|
|
||||||
|
# create a loss to balance the mean to 0 between the two predictions
|
||||||
|
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
|
||||||
|
|
||||||
|
# pred_pos = pred_pos - baseline_prediction
|
||||||
|
# pred_neg = pred_neg - baseline_prediction
|
||||||
|
|
||||||
|
pred_loss = torch.nn.functional.mse_loss(
|
||||||
|
pred_pos.float(),
|
||||||
|
noise.float(),
|
||||||
|
reduction="none"
|
||||||
|
)
|
||||||
|
# apply mask
|
||||||
|
pred_loss = pred_loss * (1.0 + differential_mask)
|
||||||
|
pred_loss = pred_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||||
|
pred_neg.float(),
|
||||||
|
noise.float(),
|
||||||
|
reduction="none"
|
||||||
|
)
|
||||||
|
# apply inverse mask
|
||||||
|
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
|
||||||
|
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
# make a loss to balance the losses of the pos and neg so they are equal
|
||||||
|
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
|
||||||
|
#
|
||||||
|
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
|
||||||
|
#
|
||||||
|
# # add a multiplier to balancing losses to make them the top priority
|
||||||
|
# differential_mean_loss = differential_mean_loss
|
||||||
|
|
||||||
|
# remove the grads from the negative as it is only a balancing loss
|
||||||
|
# pred_neg_loss = pred_neg_loss.detach()
|
||||||
|
|
||||||
|
# loss = pred_loss + pred_neg_loss + differential_mean_loss
|
||||||
|
loss = pred_loss + pred_neg_loss
|
||||||
|
|
||||||
|
# loss = self.apply_snr(loss, timesteps)
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
self.before_unet_predict()
|
self.before_unet_predict()
|
||||||
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
||||||
if batch.unconditional_latents is not None:
|
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||||
# do guided loss
|
# do guided loss
|
||||||
loss = self.get_guided_loss(
|
loss = self.get_guided_loss(
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
|
|||||||
@@ -293,6 +293,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# will end in safetensors or pt
|
# will end in safetensors or pt
|
||||||
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
|
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
|
||||||
|
|
||||||
|
# check for critic files
|
||||||
|
critic_pattern = f"CRITIC_{self.job.name}_*"
|
||||||
|
critic_items = glob.glob(os.path.join(self.save_root, critic_pattern))
|
||||||
|
|
||||||
# Sort the lists by creation time if they are not empty
|
# Sort the lists by creation time if they are not empty
|
||||||
if safetensors_files:
|
if safetensors_files:
|
||||||
safetensors_files.sort(key=os.path.getctime)
|
safetensors_files.sort(key=os.path.getctime)
|
||||||
@@ -302,6 +306,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
directories.sort(key=os.path.getctime)
|
directories.sort(key=os.path.getctime)
|
||||||
if embed_files:
|
if embed_files:
|
||||||
embed_files.sort(key=os.path.getctime)
|
embed_files.sort(key=os.path.getctime)
|
||||||
|
if critic_items:
|
||||||
|
critic_items.sort(key=os.path.getctime)
|
||||||
|
|
||||||
# Combine and sort the lists
|
# Combine and sort the lists
|
||||||
combined_items = safetensors_files + directories + pt_files
|
combined_items = safetensors_files + directories + pt_files
|
||||||
@@ -313,8 +319,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
||||||
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
||||||
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
|
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
|
||||||
|
critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
|
||||||
|
|
||||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
|
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
|
||||||
|
|
||||||
# remove all but the latest max_step_saves_to_keep
|
# remove all but the latest max_step_saves_to_keep
|
||||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||||
@@ -1041,8 +1048,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
train_text_encoder=self.train_config.train_text_encoder,
|
train_text_encoder=self.train_config.train_text_encoder,
|
||||||
conv_lora_dim=self.network_config.conv,
|
conv_lora_dim=self.network_config.conv,
|
||||||
conv_alpha=self.network_config.conv_alpha,
|
conv_alpha=self.network_config.conv_alpha,
|
||||||
is_sdxl=self.model_config.is_xl,
|
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
|
||||||
is_v2=self.model_config.is_v2,
|
is_v2=self.model_config.is_v2,
|
||||||
|
is_ssd=self.model_config.is_ssd,
|
||||||
dropout=self.network_config.dropout,
|
dropout=self.network_config.dropout,
|
||||||
use_text_encoder_1=self.model_config.use_text_encoder_1,
|
use_text_encoder_1=self.model_config.use_text_encoder_1,
|
||||||
use_text_encoder_2=self.model_config.use_text_encoder_2,
|
use_text_encoder_2=self.model_config.use_text_encoder_2,
|
||||||
|
|||||||
@@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# get a random number of steps
|
# get a random number of steps
|
||||||
timesteps_to = torch.randint(
|
timesteps_to = torch.randint(
|
||||||
1, self.train_config.max_denoising_steps, (1,)
|
1, self.train_config.max_denoising_steps - 1, (1,)
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
# get noise
|
# get noise
|
||||||
@@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
assert not self.network.is_active
|
assert not self.network.is_active
|
||||||
self.sd.unet.eval()
|
self.sd.unet.eval()
|
||||||
# pass the multiplier list to the network
|
# pass the multiplier list to the network
|
||||||
self.network.multiplier = prompt_pair.multiplier_list
|
# double up since we are doing cfg
|
||||||
|
self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
|
||||||
denoised_latents = self.sd.diffuse_some_steps(
|
denoised_latents = self.sd.diffuse_some_steps(
|
||||||
latents, # pass simple noise latents
|
latents, # pass simple noise latents
|
||||||
train_tools.concat_prompt_embeddings(
|
train_tools.concat_prompt_embeddings(
|
||||||
@@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
|
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
|
||||||
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
|
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
|
||||||
):
|
):
|
||||||
self.network.multiplier = anchor_chunk.multiplier_list
|
self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
|
||||||
|
|
||||||
anchor_pred_noise = get_noise_pred(
|
anchor_pred_noise = get_noise_pred(
|
||||||
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
|
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
|
||||||
@@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
mask_multiplier_chunks,
|
mask_multiplier_chunks,
|
||||||
unmasked_target_chunks
|
unmasked_target_chunks
|
||||||
):
|
):
|
||||||
self.network.multiplier = prompt_pair_chunk.multiplier_list
|
self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
|
||||||
target_latents = get_noise_pred(
|
target_latents = get_noise_pred(
|
||||||
prompt_pair_chunk.positive_target,
|
prompt_pair_chunk.positive_target,
|
||||||
prompt_pair_chunk.target_class,
|
prompt_pair_chunk.target_class,
|
||||||
@@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
offset_neutral = neutral_latents_chunk
|
offset_neutral = neutral_latents_chunk
|
||||||
# offsets are already adjusted on a per-batch basis
|
# offsets are already adjusted on a per-batch basis
|
||||||
offset_neutral += offset
|
offset_neutral += offset
|
||||||
|
offset_neutral = offset_neutral.detach().requires_grad_(False)
|
||||||
|
|
||||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||||
|
|||||||
20
scripts/generate_sampler_step_scales.py
Normal file
20
scripts/generate_sampler_step_scales.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
import sys
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
# add project root to path
|
||||||
|
sys.path.append(PROJECT_ROOT)
|
||||||
|
|
||||||
|
SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales')
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Process some images.')
|
||||||
|
add_arg = parser.add_argument
|
||||||
|
add_arg('--model', type=str, required=True, help='Path to model')
|
||||||
|
add_arg('--sampler', type=str, required=True, help='Name of sampler')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Literal, Union
|
from typing import List, Optional, Literal, Union, TYPE_CHECKING
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,6 +11,8 @@ ImgExt = Literal['jpg', 'png', 'webp']
|
|||||||
|
|
||||||
SaveFormat = Literal['safetensors', 'diffusers']
|
SaveFormat = Literal['safetensors', 'diffusers']
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.guidance import GuidanceType
|
||||||
|
|
||||||
class SaveConfig:
|
class SaveConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -400,6 +402,7 @@ class DatasetConfig:
|
|||||||
if legacy_caption_type:
|
if legacy_caption_type:
|
||||||
self.caption_ext = legacy_caption_type
|
self.caption_ext = legacy_caption_type
|
||||||
self.caption_type = self.caption_ext
|
self.caption_type = self.caption_ext
|
||||||
|
self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted_polarity')
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from toolkit.basic import value_map
|
||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
@@ -8,13 +10,16 @@ from toolkit.train_tools import get_torch_dtype
|
|||||||
GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
|
GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
|
||||||
|
|
||||||
DIFFERENTIAL_SCALER = 0.2
|
DIFFERENTIAL_SCALER = 0.2
|
||||||
|
|
||||||
|
|
||||||
# DIFFERENTIAL_SCALER = 0.25
|
# DIFFERENTIAL_SCALER = 0.25
|
||||||
|
|
||||||
|
|
||||||
def get_differential_mask(
|
def get_differential_mask(
|
||||||
conditional_latents: torch.Tensor,
|
conditional_latents: torch.Tensor,
|
||||||
unconditional_latents: torch.Tensor,
|
unconditional_latents: torch.Tensor,
|
||||||
threshold: float = 0.2
|
threshold: float = 0.2,
|
||||||
|
gradient: bool = False,
|
||||||
):
|
):
|
||||||
# make a differential mask
|
# make a differential mask
|
||||||
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||||
@@ -23,12 +28,27 @@ def get_differential_mask(
|
|||||||
differential_scaler = 1.0 / max_differential
|
differential_scaler = 1.0 / max_differential
|
||||||
differential_mask = differential_mask * differential_scaler
|
differential_mask = differential_mask * differential_scaler
|
||||||
|
|
||||||
# make everything less than 0.2 be 0.0 and everything else be 1.0
|
if gradient:
|
||||||
differential_mask = torch.where(
|
# we need to scale it to 0-1
|
||||||
differential_mask < threshold,
|
# differential_mask = differential_mask - differential_mask.min()
|
||||||
torch.zeros_like(differential_mask),
|
# differential_mask = differential_mask / differential_mask.max()
|
||||||
torch.ones_like(differential_mask)
|
# add 0.2 threshold to both sides and clip
|
||||||
)
|
differential_mask = value_map(
|
||||||
|
differential_mask,
|
||||||
|
differential_mask.min(),
|
||||||
|
differential_mask.max(),
|
||||||
|
0 - threshold,
|
||||||
|
1 + threshold
|
||||||
|
)
|
||||||
|
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# make everything less than 0.2 be 0.0 and everything else be 1.0
|
||||||
|
differential_mask = torch.where(
|
||||||
|
differential_mask < threshold,
|
||||||
|
torch.zeros_like(differential_mask),
|
||||||
|
torch.ones_like(differential_mask)
|
||||||
|
)
|
||||||
return differential_mask
|
return differential_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +67,6 @@ def get_targeted_polarity_loss(
|
|||||||
dtype = get_torch_dtype(sd.torch_dtype)
|
dtype = get_torch_dtype(sd.torch_dtype)
|
||||||
device = sd.device_torch
|
device = sd.device_torch
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||||
|
|
||||||
@@ -164,7 +183,7 @@ def get_targeted_polarity_loss(
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
# This targets only the positive differential
|
|
||||||
# targeted
|
# targeted
|
||||||
def get_targeted_guidance_loss(
|
def get_targeted_guidance_loss(
|
||||||
noisy_latents: torch.Tensor,
|
noisy_latents: torch.Tensor,
|
||||||
@@ -183,35 +202,71 @@ def get_targeted_guidance_loss(
|
|||||||
dtype = get_torch_dtype(sd.torch_dtype)
|
dtype = get_torch_dtype(sd.torch_dtype)
|
||||||
device = sd.device_torch
|
device = sd.device_torch
|
||||||
|
|
||||||
|
|
||||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||||
|
|
||||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
# apply random offset to both latents
|
||||||
|
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
offset = offset * 0.1
|
||||||
|
conditional_latents = conditional_latents + offset
|
||||||
|
unconditional_latents = unconditional_latents + offset
|
||||||
|
|
||||||
|
# get random scale 0f 0.8 to 1.2
|
||||||
|
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
scale = scale * 0.4
|
||||||
|
scale = scale + 0.8
|
||||||
|
conditional_latents = conditional_latents * scale
|
||||||
|
unconditional_latents = unconditional_latents * scale
|
||||||
|
|
||||||
diff_mask = get_differential_mask(
|
diff_mask = get_differential_mask(
|
||||||
conditional_latents,
|
conditional_latents,
|
||||||
unconditional_latents,
|
unconditional_latents,
|
||||||
threshold=0.1
|
threshold=0.2,
|
||||||
|
gradient=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# this is a magic number I spent weeks deducing. It works and I have no idea why.
|
# standardize inpute to std of 1
|
||||||
# unconditional_diff_noise = unconditional_diff * DIFFERENTIAL_SCALER
|
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||||
|
#
|
||||||
|
# # scale the latents to std of 1
|
||||||
|
# conditional_latents = conditional_latents / combo_std
|
||||||
|
# unconditional_latents = unconditional_latents / combo_std
|
||||||
|
|
||||||
inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
|
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||||
noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
|
|
||||||
diff_noise_scaler = noise_abs_mean / inputs_abs_mean
|
|
||||||
unconditional_diff_noise = unconditional_diff * diff_noise_scaler
|
|
||||||
|
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||||
|
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
# noise_multiplier = noise_multiplier - 0.5
|
||||||
|
noise_multiplier = 1.0
|
||||||
|
|
||||||
|
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||||
|
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||||
|
|
||||||
|
# scale it to the timestep
|
||||||
|
unconditional_diff_noise = sd.add_noise(
|
||||||
|
torch.zeros_like(unconditional_latents),
|
||||||
|
unconditional_diff_noise,
|
||||||
|
timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||||
|
|
||||||
|
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||||
|
|
||||||
baseline_noisy_latents = sd.add_noise(
|
baseline_noisy_latents = sd.add_noise(
|
||||||
conditional_latents,
|
unconditional_latents,
|
||||||
noise,
|
noise,
|
||||||
timesteps
|
timesteps
|
||||||
).detach()
|
).detach()
|
||||||
|
|
||||||
|
target_noise = noise + unconditional_diff_noise
|
||||||
noisy_latents = sd.add_noise(
|
noisy_latents = sd.add_noise(
|
||||||
conditional_latents,
|
conditional_latents,
|
||||||
# noise + unconditional_diff_noise,
|
target_noise,
|
||||||
noise,
|
# noise,
|
||||||
timesteps
|
timesteps
|
||||||
).detach()
|
).detach()
|
||||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||||
@@ -226,15 +281,20 @@ def get_targeted_guidance_loss(
|
|||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
guidance_scale=1.0,
|
guidance_scale=1.0,
|
||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
).detach()
|
).detach().requires_grad_(False)
|
||||||
|
|
||||||
|
# determine the error for the baseline prediction
|
||||||
|
baseline_prediction_error = baseline_prediction - noise
|
||||||
|
|
||||||
|
|
||||||
# turn the LoRA network back on.
|
# turn the LoRA network back on.
|
||||||
sd.unet.train()
|
sd.unet.train()
|
||||||
sd.network.is_active = True
|
sd.network.is_active = True
|
||||||
|
|
||||||
sd.network.multiplier = network_weight_list
|
sd.network.multiplier = network_weight_list
|
||||||
|
|
||||||
unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||||
masked_noise = noise * diff_mask
|
# masked_noise = noise * diff_mask
|
||||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||||
|
|
||||||
# do our prediction with LoRA active on the scaled guidance latents
|
# do our prediction with LoRA active on the scaled guidance latents
|
||||||
@@ -246,30 +306,32 @@ def get_targeted_guidance_loss(
|
|||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
)
|
)
|
||||||
|
|
||||||
prediction = prediction - unmasked_baseline_prediction
|
|
||||||
# prediction = prediction - baseline_prediction
|
|
||||||
|
|
||||||
baseline_loss = torch.nn.functional.mse_loss(
|
|
||||||
baseline_prediction.float(),
|
baselined_prediction = prediction - baseline_prediction
|
||||||
noise.float(),
|
|
||||||
|
guidance_loss = torch.nn.functional.mse_loss(
|
||||||
|
baselined_prediction.float(),
|
||||||
|
# unconditional_diff_noise.float(),
|
||||||
|
unconditional_diff_noise.float(),
|
||||||
reduction="none"
|
reduction="none"
|
||||||
)
|
)
|
||||||
baseline_loss = baseline_loss * (1.0 - diff_mask)
|
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||||
baseline_loss = baseline_loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
# loss = torch.nn.functional.l1_loss(
|
guidance_loss = guidance_loss.mean()
|
||||||
loss = torch.nn.functional.mse_loss(
|
|
||||||
|
|
||||||
|
# do the masked noise prediction
|
||||||
|
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||||
prediction.float(),
|
prediction.float(),
|
||||||
masked_noise.float(),
|
target_noise.float(),
|
||||||
reduction="none"
|
reduction="none"
|
||||||
)
|
) * diff_mask
|
||||||
loss = loss * diff_mask
|
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||||
loss = loss.mean([1, 2, 3])
|
masked_noise_loss = masked_noise_loss.mean()
|
||||||
primary_loss_scaler = 1.0
|
|
||||||
|
|
||||||
loss = (loss * primary_loss_scaler) + baseline_loss
|
|
||||||
|
|
||||||
loss = loss.mean()
|
loss = guidance_loss + masked_noise_loss
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -280,8 +342,158 @@ def get_targeted_guidance_loss(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def get_targeted_guidance_loss_WIP(
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
conditional_embeds: 'PromptEmbeds',
|
||||||
|
match_adapter_assist: bool,
|
||||||
|
network_weight_list: list,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
pred_kwargs: dict,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
noise: torch.Tensor,
|
||||||
|
sd: 'StableDiffusion',
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
# Perform targeted guidance (working title)
|
||||||
|
dtype = get_torch_dtype(sd.torch_dtype)
|
||||||
|
device = sd.device_torch
|
||||||
|
|
||||||
|
|
||||||
|
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||||
|
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||||
|
|
||||||
|
# apply random offset to both latents
|
||||||
|
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
offset = offset * 0.1
|
||||||
|
conditional_latents = conditional_latents + offset
|
||||||
|
unconditional_latents = unconditional_latents + offset
|
||||||
|
|
||||||
|
# get random scale 0f 0.8 to 1.2
|
||||||
|
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
scale = scale * 0.4
|
||||||
|
scale = scale + 0.8
|
||||||
|
conditional_latents = conditional_latents * scale
|
||||||
|
unconditional_latents = unconditional_latents * scale
|
||||||
|
|
||||||
|
diff_mask = get_differential_mask(
|
||||||
|
conditional_latents,
|
||||||
|
unconditional_latents,
|
||||||
|
threshold=0.2,
|
||||||
|
gradient=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# standardize inpute to std of 1
|
||||||
|
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||||
|
#
|
||||||
|
# # scale the latents to std of 1
|
||||||
|
# conditional_latents = conditional_latents / combo_std
|
||||||
|
# unconditional_latents = unconditional_latents / combo_std
|
||||||
|
|
||||||
|
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||||
|
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
|
# noise_multiplier = noise_multiplier - 0.5
|
||||||
|
noise_multiplier = 1.0
|
||||||
|
|
||||||
|
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||||
|
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||||
|
|
||||||
|
# scale it to the timestep
|
||||||
|
unconditional_diff_noise = sd.add_noise(
|
||||||
|
torch.zeros_like(unconditional_latents),
|
||||||
|
unconditional_diff_noise,
|
||||||
|
timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||||
|
|
||||||
|
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||||
|
|
||||||
|
baseline_noisy_latents = sd.add_noise(
|
||||||
|
unconditional_latents,
|
||||||
|
noise,
|
||||||
|
timesteps
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
target_noise = noise + unconditional_diff_noise
|
||||||
|
noisy_latents = sd.add_noise(
|
||||||
|
conditional_latents,
|
||||||
|
target_noise,
|
||||||
|
# noise,
|
||||||
|
timesteps
|
||||||
|
).detach()
|
||||||
|
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||||
|
sd.network.is_active = False
|
||||||
|
sd.unet.eval()
|
||||||
|
|
||||||
|
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||||
|
# This acts as our control to preserve the unaltered parts of the image.
|
||||||
|
baseline_prediction = sd.predict_noise(
|
||||||
|
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
).detach().requires_grad_(False)
|
||||||
|
|
||||||
|
# turn the LoRA network back on.
|
||||||
|
sd.unet.train()
|
||||||
|
sd.network.is_active = True
|
||||||
|
|
||||||
|
sd.network.multiplier = network_weight_list
|
||||||
|
|
||||||
|
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||||
|
# masked_noise = noise * diff_mask
|
||||||
|
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||||
|
|
||||||
|
# do our prediction with LoRA active on the scaled guidance latents
|
||||||
|
prediction = sd.predict_noise(
|
||||||
|
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
baselined_prediction = prediction - baseline_prediction
|
||||||
|
|
||||||
|
guidance_loss = torch.nn.functional.mse_loss(
|
||||||
|
baselined_prediction.float(),
|
||||||
|
# unconditional_diff_noise.float(),
|
||||||
|
unconditional_diff_noise.float(),
|
||||||
|
reduction="none"
|
||||||
|
)
|
||||||
|
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
guidance_loss = guidance_loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
# do the masked noise prediction
|
||||||
|
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||||
|
prediction.float(),
|
||||||
|
target_noise.float(),
|
||||||
|
reduction="none"
|
||||||
|
) * diff_mask
|
||||||
|
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||||
|
masked_noise_loss = masked_noise_loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
loss = guidance_loss + masked_noise_loss
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# detach it so parent class can run backward on no grads without throwing error
|
||||||
|
loss = loss.detach()
|
||||||
|
loss.requires_grad_(True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
def get_guided_loss_polarity(
|
def get_guided_loss_polarity(
|
||||||
noisy_latents: torch.Tensor,
|
noisy_latents: torch.Tensor,
|
||||||
conditional_embeds: PromptEmbeds,
|
conditional_embeds: PromptEmbeds,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig
|
|||||||
from toolkit.lorm import extract_conv, extract_linear, count_parameters
|
from toolkit.lorm import extract_conv, extract_linear, count_parameters
|
||||||
from toolkit.metadata import add_model_hash_to_meta
|
from toolkit.metadata import add_model_hash_to_meta
|
||||||
from toolkit.paths import KEYMAPS_ROOT
|
from toolkit.paths import KEYMAPS_ROOT
|
||||||
|
from toolkit.saving import get_lora_keymap_from_model_keymap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
||||||
@@ -338,6 +339,7 @@ class ToolkitNetworkMixin:
|
|||||||
train_unet: Optional[bool] = True,
|
train_unet: Optional[bool] = True,
|
||||||
is_sdxl=False,
|
is_sdxl=False,
|
||||||
is_v2=False,
|
is_v2=False,
|
||||||
|
is_ssd=False,
|
||||||
network_config: Optional[NetworkConfig] = None,
|
network_config: Optional[NetworkConfig] = None,
|
||||||
is_lorm=False,
|
is_lorm=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -348,6 +350,7 @@ class ToolkitNetworkMixin:
|
|||||||
self._multiplier: float = 1.0
|
self._multiplier: float = 1.0
|
||||||
self.is_active: bool = False
|
self.is_active: bool = False
|
||||||
self.is_sdxl = is_sdxl
|
self.is_sdxl = is_sdxl
|
||||||
|
self.is_ssd = is_ssd
|
||||||
self.is_v2 = is_v2
|
self.is_v2 = is_v2
|
||||||
self.is_merged_in = False
|
self.is_merged_in = False
|
||||||
self.is_lorm = is_lorm
|
self.is_lorm = is_lorm
|
||||||
@@ -357,14 +360,25 @@ class ToolkitNetworkMixin:
|
|||||||
self.can_merge_in = not is_lorm
|
self.can_merge_in = not is_lorm
|
||||||
|
|
||||||
def get_keymap(self: Network):
|
def get_keymap(self: Network):
|
||||||
if self.is_sdxl:
|
use_weight_mapping = False
|
||||||
|
|
||||||
|
if self.is_ssd:
|
||||||
|
keymap_tail = 'ssd'
|
||||||
|
use_weight_mapping = True
|
||||||
|
elif self.is_sdxl:
|
||||||
keymap_tail = 'sdxl'
|
keymap_tail = 'sdxl'
|
||||||
elif self.is_v2:
|
elif self.is_v2:
|
||||||
keymap_tail = 'sd2'
|
keymap_tail = 'sd2'
|
||||||
else:
|
else:
|
||||||
keymap_tail = 'sd1'
|
keymap_tail = 'sd1'
|
||||||
|
# todo double check this
|
||||||
|
use_weight_mapping = True
|
||||||
|
|
||||||
# load keymap
|
# load keymap
|
||||||
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
|
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
|
||||||
|
if use_weight_mapping:
|
||||||
|
keymap_name = f"stable_diffusion_{keymap_tail}.json"
|
||||||
|
|
||||||
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
|
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
|
||||||
|
|
||||||
keymap = None
|
keymap = None
|
||||||
@@ -373,6 +387,10 @@ class ToolkitNetworkMixin:
|
|||||||
with open(keymap_path, 'r') as f:
|
with open(keymap_path, 'r') as f:
|
||||||
keymap = json.load(f)['ldm_diffusers_keymap']
|
keymap = json.load(f)['ldm_diffusers_keymap']
|
||||||
|
|
||||||
|
if use_weight_mapping and keymap is not None:
|
||||||
|
# get keymap from weights
|
||||||
|
keymap = get_lora_keymap_from_model_keymap(keymap)
|
||||||
|
|
||||||
return keymap
|
return keymap
|
||||||
|
|
||||||
def save_weights(
|
def save_weights(
|
||||||
|
|||||||
@@ -206,6 +206,7 @@ def load_t2i_model(
|
|||||||
|
|
||||||
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
||||||
|
|
||||||
|
|
||||||
def save_ip_adapter_from_diffusers(
|
def save_ip_adapter_from_diffusers(
|
||||||
combined_state_dict: 'OrderedDict',
|
combined_state_dict: 'OrderedDict',
|
||||||
output_file: str,
|
output_file: str,
|
||||||
@@ -241,3 +242,58 @@ def load_ip_adapter_model(
|
|||||||
return combined_state_dict
|
return combined_state_dict
|
||||||
else:
|
else:
|
||||||
return torch.load(path_to_file, map_location=device)
|
return torch.load(path_to_file, map_location=device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
    """Build a LoRA key map from a full-model key map.

    Every non-bias entry of ``model_keymap`` is rewritten into LoRA naming
    (``lora_unet`` / ``lora_te`` / ``lora_te1`` / ``lora_te2`` prefixes, dots
    replaced by underscores) and expanded into five entries: ``lora_down`` and
    ``lora_up`` weight and bias, plus ``alpha``.

    Args:
        model_keymap: Mapping of source key -> target key for the base model.
            Presumably ldm key -> diffusers key, as loaded by the caller from
            ``ldm_diffusers_keymap`` -- confirm against the keymap files.

    Returns:
        An ``OrderedDict`` mapping LoRA source keys to LoRA target keys, in the
        iteration order of ``model_keymap``.
    """
    weight_suffix = '.weight'
    lora_suffixes = (
        'lora_down.weight',
        'lora_down.bias',
        'lora_up.weight',
        'lora_up.bias',
        'alpha',
    )
    lora_keymap = OrderedDict()

    # Dual text encoders are detected by the presence of a second embedder.
    has_dual_text_encoders = any(
        key.startswith('conditioner.embedders.1') for key in model_keymap
    )

    for key, value in model_keymap.items():
        # Bias entries get no LoRA modules of their own.
        if key.endswith('bias'):
            continue
        if key.endswith(weight_suffix):
            key = key[:-len(weight_suffix)]
        if value.endswith(weight_suffix):
            value = value[:-len(weight_suffix)]

        # unet (same for every model family)
        key = key.replace('model.diffusion_model', 'lora_unet')
        if value.startswith('unet'):
            value = f"lora_{value}"

        # text encoders
        if has_dual_text_encoders:
            key = key.replace('conditioner.embedders.0', 'lora_te1')
            key = key.replace('conditioner.embedders.1', 'lora_te2')
            if value.startswith(('te0', 'te1')):
                value = f"lora_{value}"
                # str.replace returns a new string, so the result must be
                # assigned. Order matters: shift te1 -> te2 first, otherwise
                # the te0 -> te1 rename would be shifted a second time.
                value = value.replace('lora_te1', 'lora_te2')
                value = value.replace('lora_te0', 'lora_te1')

        key = key.replace('cond_stage_model.transformer', 'lora_te')
        if value.startswith('te_'):
            value = f"lora_{value}"

        # LoRA module names use underscores instead of dots.
        key = key.replace('.', '_')
        value = value.replace('.', '_')

        for suffix in lora_suffixes:
            lora_keymap[f"{key}.{suffix}"] = f"{value}.{suffix}"

    return lora_keymap
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
|
|||||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||||
StableDiffusionXLImg2ImgPipeline
|
StableDiffusionXLImg2ImgPipeline, LCMScheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import \
|
from diffusers import \
|
||||||
AutoencoderKL, \
|
AutoencoderKL, \
|
||||||
@@ -279,6 +279,20 @@ class StableDiffusion:
|
|||||||
self.load_refiner()
|
self.load_refiner()
|
||||||
self.is_loaded = True
|
self.is_loaded = True
|
||||||
|
|
||||||
|
def te_train(self):
    """Switch the text encoder(s) to training mode.

    ``self.text_encoder`` may be a single encoder or a list/tuple of
    encoders; ``.train()`` is called on each one.
    """
    encoders = self.text_encoder
    # Normalize the single-encoder case so both shapes share one code path.
    if not isinstance(encoders, (list, tuple)):
        encoders = [encoders]
    for te in encoders:
        te.train()
|
||||||
|
|
||||||
|
def te_eval(self):
    """Switch the text encoder(s) to evaluation mode.

    ``self.text_encoder`` may be a single encoder or a list/tuple of
    encoders; ``.eval()`` is called on each one.
    """
    encoders = self.text_encoder
    # Normalize the single-encoder case so both shapes share one code path.
    if not isinstance(encoders, (list, tuple)):
        encoders = [encoders]
    for te in encoders:
        te.eval()
|
||||||
|
|
||||||
def load_refiner(self):
|
def load_refiner(self):
|
||||||
# for now, we are just going to rely on the TE from the base model
|
# for now, we are just going to rely on the TE from the base model
|
||||||
# which is TE2 for SDXL and TE for SD (no refiner currently)
|
# which is TE2 for SDXL and TE for SD (no refiner currently)
|
||||||
@@ -721,6 +735,7 @@ class StableDiffusion:
|
|||||||
add_time_ids=None,
|
add_time_ids=None,
|
||||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||||
|
is_input_scaled=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -764,6 +779,8 @@ class StableDiffusion:
|
|||||||
|
|
||||||
|
|
||||||
def scale_model_input(model_input, timestep_tensor):
|
def scale_model_input(model_input, timestep_tensor):
|
||||||
|
if is_input_scaled:
|
||||||
|
return model_input
|
||||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||||
out_chunks = []
|
out_chunks = []
|
||||||
@@ -859,7 +876,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
latent_model_input,
|
latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||||
timestep,
|
timestep,
|
||||||
encoder_hidden_states=text_embeddings.text_embeds,
|
encoder_hidden_states=text_embeddings.text_embeds,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
@@ -903,7 +920,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
latent_model_input,
|
latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||||
timestep,
|
timestep,
|
||||||
encoder_hidden_states=text_embeddings.text_embeds,
|
encoder_hidden_states=text_embeddings.text_embeds,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -924,6 +941,15 @@ class StableDiffusion:
|
|||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
||||||
|
# // sometimes they are on the wrong device, no idea why
|
||||||
|
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
|
||||||
|
try:
|
||||||
|
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||||
|
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
||||||
|
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||||
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
|
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
|
||||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||||
@@ -955,10 +981,12 @@ class StableDiffusion:
|
|||||||
add_time_ids=None,
|
add_time_ids=None,
|
||||||
bleed_ratio: float = 0.5,
|
bleed_ratio: float = 0.5,
|
||||||
bleed_latents: torch.FloatTensor = None,
|
bleed_latents: torch.FloatTensor = None,
|
||||||
|
is_input_scaled=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
|
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
|
||||||
|
|
||||||
|
|
||||||
for timestep in tqdm(timesteps_to_run, leave=False):
|
for timestep in tqdm(timesteps_to_run, leave=False):
|
||||||
timestep = timestep.unsqueeze_(0)
|
timestep = timestep.unsqueeze_(0)
|
||||||
noise_pred = self.predict_noise(
|
noise_pred = self.predict_noise(
|
||||||
@@ -967,6 +995,7 @@ class StableDiffusion:
|
|||||||
timestep,
|
timestep,
|
||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
|
is_input_scaled=is_input_scaled,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# some schedulers need to run separately, so do that. (euler for example)
|
# some schedulers need to run separately, so do that. (euler for example)
|
||||||
@@ -977,6 +1006,9 @@ class StableDiffusion:
|
|||||||
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
||||||
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
|
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
|
||||||
|
|
||||||
|
# only skip first scaling
|
||||||
|
is_input_scaled = False
|
||||||
|
|
||||||
# return latents_steps
|
# return latents_steps
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user