diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 41466732..a64c289f 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -4,6 +4,7 @@ from diffusers import T2IAdapter from toolkit import train_tools from toolkit.basic import value_map +from toolkit.config_modules import GuidanceConfig from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.ip_adapter import IPAdapter from toolkit.prompt_utils import PromptEmbeds @@ -32,7 +33,6 @@ class SDTrainer(BaseSDTrainProcess): super().__init__(process_id, job, config, **kwargs) self.assistant_adapter: Union['T2IAdapter', None] self.do_prior_prediction = False - self.target_class = self.get_conf('target_class', '') if self.train_config.inverted_mask_prior: self.do_prior_prediction = True @@ -187,84 +187,84 @@ class SDTrainer(BaseSDTrainProcess): **kwargs ): with torch.no_grad(): + conditional_noisy_latents = noisy_latents dtype = get_torch_dtype(self.train_config.dtype) - # target class is unconditional - target_class_embeds = self.sd.encode_prompt(self.target_class).detach() if batch.unconditional_latents is not None: - # do the unconditional prediction here instead of a prior prediction - unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise, - timesteps) + # Encode the unconditional image into latents + unconditional_noisy_latents = self.sd.noise_scheduler.add_noise( + batch.unconditional_latents, noise, timesteps + ) - was_network_active = self.network.is_active + # was_network_active = self.network.is_active self.network.is_active = False self.sd.unet.eval() - guidance_scale = 1.0 + # calculate the differential between our conditional (target image) and out unconditional ("bad" image) + target_differential = unconditional_noisy_latents - conditional_noisy_latents + target_differential = target_differential.detach() - def cfg(uncon, 
con): - return uncon + guidance_scale * ( - con - uncon - ) + # add the target differential to the target latents as if it were noise with the scheduler scaled to + # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done + # properly + guidance_latents = self.sd.noise_scheduler.add_noise( + conditional_noisy_latents, + target_differential, + timesteps + ) - target_conditional = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + # With LoRA network bypassed, predict noise to get a baseline of what the network + # wants to do with the latents + noise. Pass our target latents here for the input. + target_unconditional = self.sd.predict_noise( + latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, guidance_scale=1.0, **pred_kwargs # adapter residuals in here ).detach() - target_unconditional = self.sd.predict_noise( - latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs # adapter residuals in here - ).detach() - - neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0 - - target_noise = cfg(target_unconditional, target_conditional) - # latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0] - - # target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise) - - # target_noise_res = noisy_latents - unconditional_noisy_latents - - # target_pred = cfg(unconditional_noisy_latents, target_pred) - # target_pred = target_pred + target_noise_res - - self.network.is_active = True + # turn the LoRA network back on. 
self.sd.unet.train() + self.network.is_active = True + self.network.multiplier = network_weight_list - prediction = self.sd.predict_noise( - latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(), + # with LoRA active, predict the noise with the scaled differential latents added. This will allow us + # the opportunity to predict the differential + noise that was added to the latents. + prediction_unconditional = self.sd.predict_noise( + latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, guidance_scale=1.0, **pred_kwargs # adapter residuals in here ) - # prediction_res = target_pred - prediction + # remove the baseline conditional prediction. This will leave only the divergence from the baseline and + # the prediction of the added differential noise + prediction_positive = prediction_unconditional - target_unconditional + # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents + # this is the diffusion training process. 
+ # This will guide the network to make identical predictions it previously did for everything EXCEPT our + # differential between the conditional and unconditional images + positive_loss = torch.nn.functional.mse_loss( + prediction_positive.float(), + target_differential.float(), + reduction="none" + ) + positive_loss = positive_loss.mean([1, 2, 3]) + # send it backwards BEFORE switching network polarity + positive_loss = self.apply_snr(positive_loss, timesteps) + positive_loss = positive_loss.mean() + positive_loss.backward() + # loss = positive_loss.detach() + negative_loss.detach() + loss = positive_loss.detach() - # prediction = cfg(prediction, target_pred) + # add a grad so other backward does not fail + loss.requires_grad_(True) - loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + # restore network + self.network.multiplier = network_weight_list - if self.train_config.learnable_snr_gos: - # add snr_gamma - loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) - elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001: - # add snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True) - elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) - - loss = loss.mean() return loss def get_prior_prediction( diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 863e4dca..01e57b1b 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -40,13 +40,13 @@ from toolkit.stable_diffusion_model import StableDiffusion from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, 
add_base_model_info_to_meta, \ parse_metadata_from_safetensors -from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma +from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight import gc from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig def flush(): @@ -94,6 +94,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.data_loader_reg: Union[DataLoader, None] = None self.trigger_word = self.get_conf('trigger_word', None) + self.guidance_config: Union[GuidanceConfig, None] = None + guidance_config_raw = self.get_conf('guidance', None) + if guidance_config_raw is not None: + self.guidance_config = GuidanceConfig(**guidance_config_raw) + # store is all are cached. 
def apply_snr(self, seperated_loss, timesteps):
    """Apply the SNR-based loss weighting selected in ``self.train_config``.

    Exactly one strategy is used, chosen in this priority order:

    1. ``learnable_snr_gos`` truthy -> ``apply_learnable_snr_gos`` with ``self.snr_gos``
    2. ``snr_gamma`` set and > 0.000001 -> ``apply_snr_weight(..., fixed=True)``
    3. ``min_snr_gamma`` set and > 0.000001 -> ``apply_snr_weight(...)``

    If none of them is enabled the loss is handed back untouched.

    :param seperated_loss: loss to weight; passed through to the helper as-is
    :param timesteps: timesteps forwarded to the weighting helper
    :return: the weighted loss, or ``seperated_loss`` itself when no strategy is enabled
    """
    settings = self.train_config

    # learnable weighting takes precedence over either gamma setting
    if settings.learnable_snr_gos:
        return apply_learnable_snr_gos(seperated_loss, timesteps, self.snr_gos)

    def is_enabled(gamma):
        # a gamma of None or (effectively) zero means "switched off"
        return gamma is not None and gamma > 0.000001

    if is_enabled(settings.snr_gamma):
        return apply_snr_weight(
            seperated_loss,
            timesteps,
            self.sd.noise_scheduler,
            settings.snr_gamma,
            fixed=True,
        )

    if is_enabled(settings.min_snr_gamma):
        return apply_snr_weight(
            seperated_loss,
            timesteps,
            self.sd.noise_scheduler,
            settings.min_snr_gamma,
        )

    return seperated_loss
class GuidanceConfig:
    """Container for the values of the ``guidance`` config section.

    Built from keyword arguments; keys that are not listed below are ignored
    and missing keys fall back to their defaults.

    Attributes:
        target_class (str): defaults to ``''``
        guidance_scale (float): defaults to ``1.0``
        positive_prompt (str): defaults to ``''``
        negative_prompt (str): defaults to ``''``
    """

    def __init__(self, **kwargs):
        # (attribute name, default) pairs, kept in declaration order
        defaults = (
            ('target_class', ''),
            ('guidance_scale', 1.0),
            ('positive_prompt', ''),
            ('negative_prompt', ''),
        )
        for attr_name, fallback in defaults:
            setattr(self, attr_name, kwargs.get(attr_name, fallback))
poi = json_data['poi'][self.poi] - self.poi_x = int(poi['x']) - self.poi_y = int(poi['y']) - self.poi_width = int(poi['width']) - self.poi_height = int(poi['height']) + # do full image if no poi + self.poi_x = 0 + self.poi_y = 0 + self.poi_width = self.width + self.poi_height = self.height + try: + if self.poi in json_data['poi']: + poi = json_data['poi'][self.poi] + self.poi_x = int(poi['x']) + self.poi_y = int(poi['y']) + self.poi_width = int(poi['width']) + self.poi_height = int(poi['height']) + except Exception as e: + pass + # handle flipping if kwargs.get('flip_x', False): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 891230b6..ca9aceff 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1059,7 +1059,7 @@ class StableDiffusion: refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') # load the refiner model dtype = get_torch_dtype(self.dtype) - model_path = self.model_config.refiner_name_or_path + model_path = self.model_config._original_refiner_name_or_path if not os.path.exists(model_path) or os.path.isdir(model_path): # TODO only load unet?? refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(