diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index f3a8a8ae..67624c98 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -2,13 +2,15 @@ from collections import OrderedDict from typing import Union, Literal, List from diffusers import T2IAdapter +import torch.functional as F from toolkit import train_tools from toolkit.basic import value_map, adain, get_mean_std from toolkit.config_modules import GuidanceConfig from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO +from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter -from toolkit.prompt_utils import PromptEmbeds +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \ apply_learnable_snr_gos, LearnableSNRGamma @@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess): self.assistant_adapter: Union['T2IAdapter', None] self.do_prior_prediction = False self.do_long_prompts = False + self.do_guided_loss = False if self.train_config.inverted_mask_prior: self.do_prior_prediction = True @@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess): batch: 'DataLoaderBatchDTO', noise: torch.Tensor, **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + **kwargs + ) + + return loss + + def get_guided_loss_targeted_polarity( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + **kwargs ): with torch.no_grad(): # Perform targeted guidance (working title) @@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess): conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() - unconditional_diff = unconditional_latents - conditional_latents - conditional_diff = conditional_latents - unconditional_latents + mean_latents = (conditional_latents + unconditional_latents) / 2.0 + + unconditional_diff = (unconditional_latents - mean_latents) + conditional_diff = (conditional_latents - mean_latents) # we need to determine the amount of signal and noise that would be present at the current timestep - conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) - unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) + # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) + # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps) + # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) - target_noise = 
noise + unconditional_signal + # target_noise = noise + unconditional_signal conditional_noisy_latents = self.sd.add_noise( - unconditional_latents + conditional_signal, - target_noise, + mean_latents, + noise, timesteps ).detach() unconditional_noisy_latents = self.sd.add_noise( - unconditional_latents, + mean_latents, noise, timesteps ).detach() @@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs # adapter residuals in here ).detach() - # turn the LoRA network back on. - self.sd.unet.train() - self.network.is_active = True - self.network.multiplier = network_weight_list + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + # since we are dividing the polarity from the middle out, we need to double our network + # weights on training since the convergent point will be at half network strength + + negative_network_weights = [weight * -2.0 for weight in network_weight_list] + positive_network_weights = [weight * 2.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. + self.sd.unet.train() + self.network.is_active = True + + self.network.multiplier = cat_network_weight_list # do our prediction with LoRA active on the scaled guidance latents prediction = self.sd.predict_noise( - latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=timesteps, + latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=cat_timesteps, guidance_scale=1.0, **pred_kwargs # adapter residuals in here ) - # remove baseline from our prediction to extract our differential prediction - prediction = prediction - baseline_prediction + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) - loss = torch.nn.functional.mse_loss( - prediction.float(), - unconditional_signal.float(), + pred_pos = pred_pos - baseline_prediction + pred_neg = pred_neg - baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + unconditional_diff.float(), reduction="none" ) - loss = loss.mean([1, 2, 3]) + pred_loss = pred_loss.mean([1, 2, 3]) - loss = self.apply_snr(loss, timesteps) + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + conditional_diff.float(), + reduction="none" + ) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + loss = (pred_loss + pred_neg_loss) / 2.0 + + # loss = self.apply_snr(loss, timesteps) + loss = loss.mean() + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + def get_guided_loss_masked_polarity( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + **kwargs + ): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(self.train_config.dtype) + + conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() + unconditional_latents = 
batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() + inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents) + + mean_latents = (conditional_latents + unconditional_latents) / 2.0 + + # unconditional_diff = (unconditional_latents - mean_latents) + # conditional_diff = (conditional_latents - mean_latents) + + # we need to determine the amount of signal and noise that would be present at the current timestep + # conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) + # unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + # unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps) + # conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps) + # unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps) + + # make a differential mask + differential_mask = torch.abs(conditional_latents - unconditional_latents) + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + differential_scaler = 1.0 / max_differential + differential_mask = differential_mask * differential_scaler + spread_point = 0.1 + # adjust mask to amplify the differential at 0.1 + differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point + # clip it + differential_mask = torch.clamp(differential_mask, 0.0, 1.0) + + # target_noise = noise + unconditional_signal + + conditional_noisy_latents = self.sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = self.sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + inverse_noisy_latents = self.sd.add_noise( + inverse_latents, + noise, + timesteps + ).detach() + + # Disable the LoRA network so we can predict parent network knowledge without it + self.network.is_active = False + self.sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + # baseline_prediction = self.sd.predict_noise( + # latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), + # conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + # timestep=timesteps, + # guidance_scale=1.0, + # **pred_kwargs # adapter residuals in here + # ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + # since we are dividing the polarity from the middle out, we need to double our network + # weights on training since the convergent point will be at half network strength + + negative_network_weights = [weight * -1.0 for weight in network_weight_list] + positive_network_weights = [weight * 1.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. 
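A minimal sketch, separate from the patch itself: the masked-polarity branch above combines two pieces that are easy to lose in the raw diff, a per-sample differential mask that is normalized, amplified around a spread point and clamped, and a doubled batch whose first half runs the LoRA with positive multipliers and whose second half runs it with negated ones. The helpers below restate those two pieces with plain tensors; the helper names and the per-sample interpretation of `network.multiplier` are assumptions for illustration, not toolkit APIs.

import torch

def build_differential_mask(cond: torch.Tensor, uncond: torch.Tensor,
                            spread_point: float = 0.1) -> torch.Tensor:
    # hypothetical helper mirroring the mask built in get_guided_loss_masked_polarity
    diff = torch.abs(cond - uncond)
    # normalize each sample so its largest difference becomes 1.0
    diff = diff / diff.amax(dim=(1, 2, 3), keepdim=True)
    # amplify values around the spread point, then clamp to [0, 1]
    return torch.clamp((diff - spread_point) * 10.0 + spread_point, 0.0, 1.0)

def build_polarity_batch(cond_noisy: torch.Tensor, uncond_noisy: torch.Tensor,
                         timesteps: torch.Tensor, weights: list):
    # hypothetical helper: one forward pass covers both polarities,
    # first half with positive LoRA weights, second half with negated weights
    latents = torch.cat([cond_noisy, uncond_noisy], dim=0)
    cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
    multipliers = [w * 1.0 for w in weights] + [w * -1.0 for w in weights]
    return latents, cat_timesteps, multipliers

After the shared forward pass, torch.chunk(prediction, 2, dim=0) splits the result back into the positive and negative halves, which is how the patch recovers pred_pos and pred_neg.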
+ self.sd.unet.train() + self.network.is_active = True + + self.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = self.sd.predict_noise( + latents=cat_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + + # create a loss to balance the mean to 0 between the two predictions + differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0 + + # pred_pos = pred_pos - baseline_prediction + # pred_neg = pred_neg - baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + noise.float(), + reduction="none" + ) + # apply mask + pred_loss = pred_loss * (1.0 + differential_mask) + pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + noise.float(), + reduction="none" + ) + # apply inverse mask + pred_neg_loss = pred_neg_loss * (1.0 - differential_mask) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + # make a loss to balance to losses of the pos and neg so they are equal + # differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss) + # + # differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss + # + # # add a multiplier to balancing losses to make them the top priority + # differential_mean_loss = differential_mean_loss + + # remove the grads from the negative as it is only a balancing loss + # pred_neg_loss = pred_neg_loss.detach() + + # loss = pred_loss + pred_neg_loss + differential_mean_loss + loss = pred_loss + pred_neg_loss + + # loss = self.apply_snr(loss, timesteps) loss = loss.mean() loss.backward() @@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess): self.before_unet_predict() # do a prior pred if we have an unconditional image, we will swap out the giadance later - if batch.unconditional_latents is not None: + if batch.unconditional_latents is not None or self.do_guided_loss: # do guided loss loss = self.get_guided_loss( noisy_latents=noisy_latents, diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 4d020a3d..8c0a9a2e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -293,6 +293,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # will end in safetensors or pt embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')] + # check for critic files + critic_pattern = f"CRITIC_{self.job.name}_*" + critic_items = glob.glob(os.path.join(self.save_root, critic_pattern)) + # Sort the lists by creation time if they are not empty if safetensors_files: safetensors_files.sort(key=os.path.getctime) @@ -302,6 +306,8 @@ class BaseSDTrainProcess(BaseTrainProcess): directories.sort(key=os.path.getctime) if embed_files: embed_files.sort(key=os.path.getctime) + if critic_items: + critic_items.sort(key=os.path.getctime) # Combine and sort the lists combined_items = safetensors_files + directories + pt_files @@ -313,8 +319,9 @@ class BaseSDTrainProcess(BaseTrainProcess): pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else [] directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else [] embeddings_to_remove = 
embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else [] + critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else [] - items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove # remove all but the latest max_step_saves_to_keep # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep] @@ -1041,8 +1048,9 @@ class BaseSDTrainProcess(BaseTrainProcess): train_text_encoder=self.train_config.train_text_encoder, conv_lora_dim=self.network_config.conv, conv_alpha=self.network_config.conv_alpha, - is_sdxl=self.model_config.is_xl, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, is_v2=self.model_config.is_v2, + is_ssd=self.model_config.is_ssd, dropout=self.network_config.dropout, use_text_encoder_1=self.model_config.use_text_encoder_1, use_text_encoder_2=self.model_config.use_text_encoder_2, diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 9f18e6bc..88b9d104 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess): # ger a random number of steps timesteps_to = torch.randint( - 1, self.train_config.max_denoising_steps, (1,) + 1, self.train_config.max_denoising_steps - 1, (1,) ).item() # get noise @@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess): assert not self.network.is_active self.sd.unet.eval() # pass the multiplier list to the network - self.network.multiplier = prompt_pair.multiplier_list + # double up since we are doing cfg + self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list denoised_latents = self.sd.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( @@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess): for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip( anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks ): - self.network.multiplier = anchor_chunk.multiplier_list + self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list anchor_pred_noise = get_noise_pred( anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk @@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess): mask_multiplier_chunks, unmasked_target_chunks ): - self.network.multiplier = prompt_pair_chunk.multiplier_list + self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list target_latents = get_noise_pred( prompt_pair_chunk.positive_target, prompt_pair_chunk.target_class, @@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess): offset_neutral = neutral_latents_chunk # offsets are already adjusted on a per-batch basis offset_neutral += offset + offset_neutral = offset_neutral.detach().requires_grad_(False) # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") diff --git a/scripts/generate_sampler_step_scales.py b/scripts/generate_sampler_step_scales.py new file mode 100644 index 00000000..11efb318 --- /dev/null +++ b/scripts/generate_sampler_step_scales.py @@ -0,0 +1,20 @@ +import argparse +import torch +import os +from diffusers import 
StableDiffusionPipeline +import sys + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# add project root to path +sys.path.append(PROJECT_ROOT) + +SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales') + + +parser = argparse.ArgumentParser(description='Process some images.') +add_arg = parser.add_argument +add_arg('--model', type=str, required=True, help='Path to model') +add_arg('--sampler', type=str, required=True, help='Name of sampler') + +args = parser.parse_args() + diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 8fdb2670..01525963 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional, Literal, Union +from typing import List, Optional, Literal, Union, TYPE_CHECKING import random import torch @@ -11,6 +11,8 @@ ImgExt = Literal['jpg', 'png', 'webp'] SaveFormat = Literal['safetensors', 'diffusers'] +if TYPE_CHECKING: + from toolkit.guidance import GuidanceType class SaveConfig: def __init__(self, **kwargs): @@ -400,6 +402,7 @@ class DatasetConfig: if legacy_caption_type: self.caption_ext = legacy_caption_type self.caption_type = self.caption_ext + self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted_polarity') def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: diff --git a/toolkit/guidance.py b/toolkit/guidance.py index 6ace58d6..cb22807a 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -1,5 +1,7 @@ import torch from typing import Literal + +from toolkit.basic import value_map from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion @@ -8,13 +10,16 @@ from toolkit.train_tools import get_torch_dtype GuidanceType = Literal["targeted", "polarity", "targeted_polarity"] DIFFERENTIAL_SCALER = 0.2 + + # DIFFERENTIAL_SCALER = 0.25 def get_differential_mask( conditional_latents: torch.Tensor, unconditional_latents: torch.Tensor, - threshold: float = 0.2 + threshold: float = 0.2, + gradient: bool = False, ): # make a differential mask differential_mask = torch.abs(conditional_latents - unconditional_latents) @@ -23,12 +28,27 @@ def get_differential_mask( differential_scaler = 1.0 / max_differential differential_mask = differential_mask * differential_scaler - # make everything less than 0.2 be 0.0 and everything else be 1.0 - differential_mask = torch.where( - differential_mask < threshold, - torch.zeros_like(differential_mask), - torch.ones_like(differential_mask) - ) + if gradient: + # wew need to scale it to 0-1 + # differential_mask = differential_mask - differential_mask.min() + # differential_mask = differential_mask / differential_mask.max() + # add 0.2 threshold to both sides and clip + differential_mask = value_map( + differential_mask, + differential_mask.min(), + differential_mask.max(), + 0 - threshold, + 1 + threshold + ) + differential_mask = torch.clamp(differential_mask, 0.0, 1.0) + else: + + # make everything less than 0.2 be 0.0 and everything else be 1.0 + differential_mask = torch.where( + differential_mask < threshold, + torch.zeros_like(differential_mask), + torch.ones_like(differential_mask) + ) return differential_mask @@ -47,7 +67,6 @@ def get_targeted_polarity_loss( dtype = get_torch_dtype(sd.torch_dtype) device = sd.device_torch with torch.no_grad(): - conditional_latents = batch.latents.to(device, 
dtype=dtype).detach() unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() @@ -164,7 +183,7 @@ def get_targeted_polarity_loss( return loss -# This targets only the positive differential + # targeted def get_targeted_guidance_loss( noisy_latents: torch.Tensor, @@ -183,35 +202,71 @@ def get_targeted_guidance_loss( dtype = get_torch_dtype(sd.torch_dtype) device = sd.device_torch + conditional_latents = batch.latents.to(device, dtype=dtype).detach() unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() - unconditional_diff = (unconditional_latents - conditional_latents) + # apply random offset to both latents + offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype) + offset = offset * 0.1 + conditional_latents = conditional_latents + offset + unconditional_latents = unconditional_latents + offset + + # get random scale 0f 0.8 to 1.2 + scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype) + scale = scale * 0.4 + scale = scale + 0.8 + conditional_latents = conditional_latents * scale + unconditional_latents = unconditional_latents * scale diff_mask = get_differential_mask( conditional_latents, unconditional_latents, - threshold=0.1 + threshold=0.2, + gradient=True ) - # this is a magic number I spent weeks deducing. It works and I have no idea why. - # unconditional_diff_noise = unconditional_diff * DIFFERENTIAL_SCALER + # standardize inpute to std of 1 + # combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True) + # + # # scale the latents to std of 1 + # conditional_latents = conditional_latents / combo_std + # unconditional_latents = unconditional_latents / combo_std - inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True) - noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True) - diff_noise_scaler = noise_abs_mean / inputs_abs_mean - unconditional_diff_noise = unconditional_diff * diff_noise_scaler + unconditional_diff = (unconditional_latents - conditional_latents) + + + + # get a -0.5 to 0.5 multiplier for the diff noise + # noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype) + # noise_multiplier = noise_multiplier - 0.5 + noise_multiplier = 1.0 + + # unconditional_diff_noise = unconditional_diff * noise_multiplier + unconditional_diff_noise = unconditional_diff * noise_multiplier + + # scale it to the timestep + unconditional_diff_noise = sd.add_noise( + torch.zeros_like(unconditional_latents), + unconditional_diff_noise, + timesteps + ) + + unconditional_diff_noise = unconditional_diff_noise * 0.2 + + unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False) baseline_noisy_latents = sd.add_noise( - conditional_latents, + unconditional_latents, noise, timesteps ).detach() + target_noise = noise + unconditional_diff_noise noisy_latents = sd.add_noise( conditional_latents, - # noise + unconditional_diff_noise, - noise, + target_noise, + # noise, timesteps ).detach() # Disable the LoRA network so we can predict parent network knowledge without it @@ -226,15 +281,20 @@ def get_targeted_guidance_loss( timestep=timesteps, guidance_scale=1.0, **pred_kwargs # adapter residuals in here - ).detach() + ).detach().requires_grad_(False) + + # determine the error for the baseline prediction + baseline_prediction_error = baseline_prediction - noise + + # turn the LoRA network back on. 
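A minimal sketch, separate from the patch itself: the passes that follow isolate what the LoRA contributes by first predicting with the network switched off, detaching that baseline, then predicting again with the network active and training only the difference toward the targeted differential noise. The restatement below assumes a generic `predict_fn` in place of `sd.predict_noise` and a duck-typed `network` with an `is_active` flag; both are stand-ins for illustration.

import torch
import torch.nn.functional as F

def lora_residual_loss(predict_fn, network, noisy_latents, baseline_latents,
                       target_residual, timesteps):
    # predict_fn and network are illustrative stand-ins, not toolkit APIs
    # baseline: what the frozen parent model predicts without the LoRA applied
    network.is_active = False
    with torch.no_grad():
        baseline = predict_fn(baseline_latents, timesteps).detach()

    # re-enable the LoRA; only this pass carries gradients
    network.is_active = True
    prediction = predict_fn(noisy_latents, timesteps)

    # train the LoRA's residual contribution toward the targeted differential
    residual = prediction - baseline
    return F.mse_loss(residual.float(), target_residual.float())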
sd.unet.train() sd.network.is_active = True sd.network.multiplier = network_weight_list - unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask) - masked_noise = noise * diff_mask + # unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask) + # masked_noise = noise * diff_mask # pred_target = unmasked_noise + unconditional_diff_noise # do our prediction with LoRA active on the scaled guidance latents @@ -246,30 +306,32 @@ def get_targeted_guidance_loss( **pred_kwargs # adapter residuals in here ) - prediction = prediction - unmasked_baseline_prediction - # prediction = prediction - baseline_prediction - baseline_loss = torch.nn.functional.mse_loss( - baseline_prediction.float(), - noise.float(), + + baselined_prediction = prediction - baseline_prediction + + guidance_loss = torch.nn.functional.mse_loss( + baselined_prediction.float(), + # unconditional_diff_noise.float(), + unconditional_diff_noise.float(), reduction="none" ) - baseline_loss = baseline_loss * (1.0 - diff_mask) - baseline_loss = baseline_loss.mean([1, 2, 3]) + guidance_loss = guidance_loss.mean([1, 2, 3]) - # loss = torch.nn.functional.l1_loss( - loss = torch.nn.functional.mse_loss( + guidance_loss = guidance_loss.mean() + + + # do the masked noise prediction + masked_noise_loss = torch.nn.functional.mse_loss( prediction.float(), - masked_noise.float(), + target_noise.float(), reduction="none" - ) - loss = loss * diff_mask - loss = loss.mean([1, 2, 3]) - primary_loss_scaler = 1.0 + ) * diff_mask + masked_noise_loss = masked_noise_loss.mean([1, 2, 3]) + masked_noise_loss = masked_noise_loss.mean() - loss = (loss * primary_loss_scaler) + baseline_loss - loss = loss.mean() + loss = guidance_loss + masked_noise_loss loss.backward() @@ -280,8 +342,158 @@ def get_targeted_guidance_loss( return loss +def get_targeted_guidance_loss_WIP( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + **kwargs +): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + # apply random offset to both latents + offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype) + offset = offset * 0.1 + conditional_latents = conditional_latents + offset + unconditional_latents = unconditional_latents + offset + + # get random scale 0f 0.8 to 1.2 + scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype) + scale = scale * 0.4 + scale = scale + 0.8 + conditional_latents = conditional_latents * scale + unconditional_latents = unconditional_latents * scale + + diff_mask = get_differential_mask( + conditional_latents, + unconditional_latents, + threshold=0.2, + gradient=True + ) + + # standardize inpute to std of 1 + # combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True) + # + # # scale the latents to std of 1 + # conditional_latents = conditional_latents / combo_std + # unconditional_latents = unconditional_latents / combo_std + + unconditional_diff = (unconditional_latents - conditional_latents) + + + + # get a -0.5 to 0.5 multiplier for the diff noise + # noise_multiplier = 
torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype) + # noise_multiplier = noise_multiplier - 0.5 + noise_multiplier = 1.0 + + # unconditional_diff_noise = unconditional_diff * noise_multiplier + unconditional_diff_noise = unconditional_diff * noise_multiplier + + # scale it to the timestep + unconditional_diff_noise = sd.add_noise( + torch.zeros_like(unconditional_latents), + unconditional_diff_noise, + timesteps + ) + + unconditional_diff_noise = unconditional_diff_noise * 0.2 + + unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False) + + baseline_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + target_noise = noise + unconditional_diff_noise + noisy_latents = sd.add_noise( + conditional_latents, + target_noise, + # noise, + timesteps + ).detach() + # Disable the LoRA network so we can predict parent network knowledge without it + sd.network.is_active = False + sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + baseline_prediction = sd.predict_noise( + latents=baseline_noisy_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ).detach().requires_grad_(False) + + # turn the LoRA network back on. + sd.unet.train() + sd.network.is_active = True + + sd.network.multiplier = network_weight_list + + # unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask) + # masked_noise = noise * diff_mask + # pred_target = unmasked_noise + unconditional_diff_noise + + # do our prediction with LoRA active on the scaled guidance latents + prediction = sd.predict_noise( + latents=noisy_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + + + baselined_prediction = prediction - baseline_prediction + + guidance_loss = torch.nn.functional.mse_loss( + baselined_prediction.float(), + # unconditional_diff_noise.float(), + unconditional_diff_noise.float(), + reduction="none" + ) + guidance_loss = guidance_loss.mean([1, 2, 3]) + + guidance_loss = guidance_loss.mean() + + + # do the masked noise prediction + masked_noise_loss = torch.nn.functional.mse_loss( + prediction.float(), + target_noise.float(), + reduction="none" + ) * diff_mask + masked_noise_loss = masked_noise_loss.mean([1, 2, 3]) + masked_noise_loss = masked_noise_loss.mean() + + + loss = guidance_loss + masked_noise_loss + + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + def get_guided_loss_polarity( noisy_latents: torch.Tensor, conditional_embeds: PromptEmbeds, diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 0644f494..37f9f83c 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig from toolkit.lorm import extract_conv, extract_linear, count_parameters from toolkit.metadata import add_model_hash_to_meta from toolkit.paths import KEYMAPS_ROOT +from toolkit.saving import get_lora_keymap_from_model_keymap if TYPE_CHECKING: from toolkit.lycoris_special import 
LycorisSpecialNetwork, LoConSpecialModule @@ -338,6 +339,7 @@ class ToolkitNetworkMixin: train_unet: Optional[bool] = True, is_sdxl=False, is_v2=False, + is_ssd=False, network_config: Optional[NetworkConfig] = None, is_lorm=False, **kwargs @@ -348,6 +350,7 @@ class ToolkitNetworkMixin: self._multiplier: float = 1.0 self.is_active: bool = False self.is_sdxl = is_sdxl + self.is_ssd = is_ssd self.is_v2 = is_v2 self.is_merged_in = False self.is_lorm = is_lorm @@ -357,14 +360,25 @@ class ToolkitNetworkMixin: self.can_merge_in = not is_lorm def get_keymap(self: Network): - if self.is_sdxl: + use_weight_mapping = False + + if self.is_ssd: + keymap_tail = 'ssd' + use_weight_mapping = True + elif self.is_sdxl: keymap_tail = 'sdxl' elif self.is_v2: keymap_tail = 'sd2' else: keymap_tail = 'sd1' + # todo double check this + use_weight_mapping = True + # load keymap keymap_name = f"stable_diffusion_locon_{keymap_tail}.json" + if use_weight_mapping: + keymap_name = f"stable_diffusion_{keymap_tail}.json" + keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name) keymap = None @@ -373,6 +387,10 @@ class ToolkitNetworkMixin: with open(keymap_path, 'r') as f: keymap = json.load(f)['ldm_diffusers_keymap'] + if use_weight_mapping and keymap is not None: + # get keymap from weights + keymap = get_lora_keymap_from_model_keymap(keymap) + return keymap def save_weights( diff --git a/toolkit/saving.py b/toolkit/saving.py index 23862cb4..5c7b37e1 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -206,6 +206,7 @@ def load_t2i_model( IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter'] + def save_ip_adapter_from_diffusers( combined_state_dict: 'OrderedDict', output_file: str, @@ -241,3 +242,58 @@ def load_ip_adapter_model( return combined_state_dict else: return torch.load(path_to_file, map_location=device) + + +def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict': + lora_keymap = OrderedDict() + + # see if we have dual text encoders " a key that starts with conditioner.embedders.1 + has_dual_text_encoders = False + for key in model_keymap: + if key.startswith('conditioner.embedders.1'): + has_dual_text_encoders = True + break + + # map through the keys and values + for key, value in model_keymap.items(): + # ignore bias weights + if key.endswith('bias'): + continue + if key.endswith('.weight'): + # remove the .weight + key = key[:-7] + if value.endswith(".weight"): + # remove the .weight + value = value[:-7] + + # unet for all + key = key.replace('model.diffusion_model', 'lora_unet') + if value.startswith('unet'): + value = f"lora_{value}" + + # text encoder + if has_dual_text_encoders: + key = key.replace('conditioner.embedders.0', 'lora_te1') + key = key.replace('conditioner.embedders.1', 'lora_te2') + if value.startswith('te0') or value.startswith('te1'): + value = f"lora_{value}" + value.replace('lora_te1', 'lora_te2') + value.replace('lora_te0', 'lora_te1') + + key = key.replace('cond_stage_model.transformer', 'lora_te') + + if value.startswith('te_'): + value = f"lora_{value}" + + # replace periods with underscores + key = key.replace('.', '_') + value = value.replace('.', '_') + + # add all the weights + lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight" + lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias" + lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight" + lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias" + lora_keymap[f"{key}.alpha"] = f"{value}.alpha" + + return lora_keymap diff --git 
a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index c2701bb3..2c7dd286 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \ - StableDiffusionXLImg2ImgPipeline + StableDiffusionXLImg2ImgPipeline, LCMScheduler import diffusers from diffusers import \ AutoencoderKL, \ @@ -279,6 +279,20 @@ class StableDiffusion: self.load_refiner() self.is_loaded = True + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + else: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + else: + self.text_encoder.eval() + def load_refiner(self): # for now, we are just going to rely on the TE from the base model # which is TE2 for SDXL and TE for SD (no refiner currently) @@ -721,6 +735,7 @@ class StableDiffusion: add_time_ids=None, conditional_embeddings: Union[PromptEmbeds, None] = None, unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=True, **kwargs, ): with torch.no_grad(): @@ -764,6 +779,8 @@ class StableDiffusion: def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) out_chunks = [] @@ -859,7 +876,7 @@ class StableDiffusion: # predict the noise residual noise_pred = self.unet( - latent_model_input, + latent_model_input.to(self.device_torch, self.torch_dtype), timestep, encoder_hidden_states=text_embeddings.text_embeds, added_cond_kwargs=added_cond_kwargs, @@ -903,7 +920,7 @@ class StableDiffusion: # predict the noise residual noise_pred = self.unet( - latent_model_input, + latent_model_input.to(self.device_torch, self.torch_dtype), timestep, encoder_hidden_states=text_embeddings.text_embeds, **kwargs, @@ -924,6 +941,15 @@ class StableDiffusion: return noise_pred def step_scheduler(self, model_input, latent_input, timestep_tensor): + # // sometimes they are on the wrong device, no idea why + if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler): + try: + self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch) + self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch) + self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch) + except Exception as e: + pass + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) @@ -955,10 +981,12 @@ class StableDiffusion: add_time_ids=None, bleed_ratio: float = 0.5, bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, **kwargs, ): timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + for timestep in tqdm(timesteps_to_run, leave=False): timestep = timestep.unsqueeze_(0) noise_pred = self.predict_noise( @@ -967,6 +995,7 @@ class StableDiffusion: timestep, guidance_scale=guidance_scale, 
                add_time_ids=add_time_ids,
+                is_input_scaled=is_input_scaled,
                 **kwargs,
             )
             # some schedulers need to run separately, so do that. (euler for example)
@@ -977,6 +1006,9 @@
             if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
                 latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
 
+            # only skip first scaling
+            is_input_scaled = False
+
 
         # return latents_steps
         return latents
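A minimal sketch, separate from the patch itself: stripped of trainer plumbing, the new targeted-polarity guidance in SDTrainer reduces to a compact recipe: collapse the paired latents to their mean, noise that mean once, run a doubled batch with the LoRA at +2x and -2x strength, and regress each half toward the opposite half-differential relative to a LoRA-free baseline. In the sketch, `predict_fn`, `add_noise` and the `network` object are stand-ins for `sd.predict_noise`, `sd.add_noise` and the toolkit network; the reductions mirror the patch.

import torch
import torch.nn.functional as F

def targeted_polarity_loss(predict_fn, add_noise, network,
                           cond_latents, uncond_latents, noise, timesteps, weights):
    # predict_fn, add_noise and network are illustrative stand-ins, not toolkit APIs
    with torch.no_grad():
        mean_latents = (cond_latents + uncond_latents) / 2.0
        uncond_diff = uncond_latents - mean_latents
        cond_diff = cond_latents - mean_latents
        noisy_mean = add_noise(mean_latents, noise, timesteps).detach()

        # baseline prediction with the LoRA disabled
        network.is_active = False
        baseline = predict_fn(noisy_mean, timesteps).detach()

    # doubled batch: +2x multipliers on the first half, -2x on the second
    network.is_active = True
    network.multiplier = [w * 2.0 for w in weights] + [w * -2.0 for w in weights]
    latents = torch.cat([noisy_mean, noisy_mean], dim=0)
    ts = torch.cat([timesteps, timesteps], dim=0)
    pred_pos, pred_neg = torch.chunk(predict_fn(latents, ts), 2, dim=0)

    # each polarity is pushed toward the opposite half-differential
    pos_loss = F.mse_loss((pred_pos - baseline).float(), uncond_diff.float(),
                          reduction="none").mean([1, 2, 3])
    neg_loss = F.mse_loss((pred_neg - baseline).float(), cond_diff.float(),
                          reduction="none").mean([1, 2, 3])
    return ((pos_loss + neg_loss) / 2.0).mean()

Doubling the multipliers to +2/-2 compensates for measuring the targets from the midpoint, so full network strength corresponds to the full conditional-to-unconditional swing, which is the reasoning stated in the patch comment about the convergent point sitting at half network strength.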