mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
diffirential guidance is WORKING (from what I can tell)
This commit is contained in:
@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
@@ -32,7 +33,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.do_prior_prediction = False
|
||||
self.target_class = self.get_conf('target_class', '')
|
||||
if self.train_config.inverted_mask_prior:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
@@ -187,84 +187,84 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
conditional_noisy_latents = noisy_latents
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# target class is unconditional
|
||||
target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
|
||||
|
||||
if batch.unconditional_latents is not None:
|
||||
# do the unconditional prediction here instead of a prior prediction
|
||||
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
|
||||
timesteps)
|
||||
# Encode the unconditional image into latents
|
||||
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
|
||||
batch.unconditional_latents, noise, timesteps
|
||||
)
|
||||
|
||||
was_network_active = self.network.is_active
|
||||
# was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
|
||||
guidance_scale = 1.0
|
||||
# calculate the differential between our conditional (target image) and out unconditional ("bad" image)
|
||||
target_differential = unconditional_noisy_latents - conditional_noisy_latents
|
||||
target_differential = target_differential.detach()
|
||||
|
||||
def cfg(uncon, con):
|
||||
return uncon + guidance_scale * (
|
||||
con - uncon
|
||||
)
|
||||
# add the target differential to the target latents as if it were noise with the scheduler scaled to
|
||||
# the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
|
||||
# properly
|
||||
guidance_latents = self.sd.noise_scheduler.add_noise(
|
||||
conditional_noisy_latents,
|
||||
target_differential,
|
||||
timesteps
|
||||
)
|
||||
|
||||
target_conditional = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
# With LoRA network bypassed, predict noise to get a baseline of what the network
|
||||
# wants to do with the latents + noise. Pass our target latents here for the input.
|
||||
target_unconditional = self.sd.predict_noise(
|
||||
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
target_unconditional = self.sd.predict_noise(
|
||||
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
|
||||
|
||||
target_noise = cfg(target_unconditional, target_conditional)
|
||||
# latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
|
||||
|
||||
# target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
|
||||
|
||||
# target_noise_res = noisy_latents - unconditional_noisy_latents
|
||||
|
||||
# target_pred = cfg(unconditional_noisy_latents, target_pred)
|
||||
# target_pred = target_pred + target_noise_res
|
||||
|
||||
self.network.is_active = True
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
self.network.multiplier = network_weight_list
|
||||
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
# with LoRA active, predict the noise with the scaled differential latents added. This will allow us
|
||||
# the opportunity to predict the differential + noise that was added to the latents.
|
||||
prediction_unconditional = self.sd.predict_noise(
|
||||
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
# prediction_res = target_pred - prediction
|
||||
# remove the baseline conditional prediction. This will leave only the divergence from the baseline and
|
||||
# the prediction of the added differential noise
|
||||
prediction_positive = prediction_unconditional - target_unconditional
|
||||
|
||||
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
|
||||
# this is the diffusion training process.
|
||||
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
|
||||
# differential between the conditional and unconditional images
|
||||
positive_loss = torch.nn.functional.mse_loss(
|
||||
prediction_positive.float(),
|
||||
target_differential.float(),
|
||||
reduction="none"
|
||||
)
|
||||
positive_loss = positive_loss.mean([1, 2, 3])
|
||||
# send it backwards BEFORE switching network polarity
|
||||
positive_loss = self.apply_snr(positive_loss, timesteps)
|
||||
positive_loss = positive_loss.mean()
|
||||
positive_loss.backward()
|
||||
# loss = positive_loss.detach() + negative_loss.detach()
|
||||
loss = positive_loss.detach()
|
||||
|
||||
# prediction = cfg(prediction, target_pred)
|
||||
# add a grad so other backward does not fail
|
||||
loss.requires_grad_(True)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# restore network
|
||||
self.network.multiplier = network_weight_list
|
||||
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
def get_prior_prediction(
|
||||
|
||||
@@ -40,13 +40,13 @@ from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
|
||||
parse_metadata_from_safetensors
|
||||
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
|
||||
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight
|
||||
import gc
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -94,6 +94,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.data_loader_reg: Union[DataLoader, None] = None
|
||||
self.trigger_word = self.get_conf('trigger_word', None)
|
||||
|
||||
self.guidance_config: Union[GuidanceConfig, None] = None
|
||||
guidance_config_raw = self.get_conf('guidance', None)
|
||||
if guidance_config_raw is not None:
|
||||
self.guidance_config = GuidanceConfig(**guidance_config_raw)
|
||||
|
||||
# store is all are cached. Allows us to not load vae if we don't need to
|
||||
self.is_latents_cached = True
|
||||
raw_datasets = self.get_conf('datasets', None)
|
||||
@@ -417,6 +422,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.print(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
self.post_save_hook(file_path)
|
||||
flush()
|
||||
|
||||
# Called before the model is loaded
|
||||
def hook_before_model_load(self):
|
||||
@@ -501,6 +507,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
print("load_weights not implemented for non-network models")
|
||||
return None
|
||||
|
||||
def apply_snr(self, seperated_loss, timesteps):
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
seperated_loss = apply_learnable_snr_gos(seperated_loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
|
||||
# add snr_gamma
|
||||
seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
return seperated_loss
|
||||
|
||||
def load_lorm(self):
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
if latest_save_path is not None:
|
||||
|
||||
@@ -242,6 +242,7 @@ class ModelConfig:
|
||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||
self.vae_path = kwargs.get('vae_path', None)
|
||||
self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
|
||||
self._original_refiner_name_or_path = self.refiner_name_or_path
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
|
||||
|
||||
# only for SDXL models for now
|
||||
@@ -286,6 +287,14 @@ class SliderTargetConfig:
|
||||
self.shuffle: bool = kwargs.get('shuffle', False)
|
||||
|
||||
|
||||
class GuidanceConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0)
|
||||
self.positive_prompt: str = kwargs.get('positive_prompt', '')
|
||||
self.negative_prompt: str = kwargs.get('negative_prompt', '')
|
||||
|
||||
|
||||
class SliderConfigAnchors:
|
||||
def __init__(self, **kwargs):
|
||||
self.prompt = kwargs.get('prompt', '')
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from functools import lru_cache
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
@@ -378,6 +379,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Error processing image: {file}")
|
||||
print(e)
|
||||
bad_count += 1
|
||||
|
||||
@@ -35,7 +35,6 @@ class FileItemDTO(
|
||||
ArgBreakMixin,
|
||||
):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.path = kwargs.get('path', None)
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
# process width and height
|
||||
@@ -48,6 +47,7 @@ class FileItemDTO(
|
||||
h, w = img.size
|
||||
self.width: int = w
|
||||
self.height: int = h
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# self.caption_path: str = kwargs.get('caption_path', None)
|
||||
self.raw_caption: str = kwargs.get('raw_caption', None)
|
||||
|
||||
@@ -779,11 +779,21 @@ class PoiFileItemDTOMixin:
|
||||
if self.poi not in json_data['poi']:
|
||||
raise Exception(f"Error: poi not found in caption file: {caption_path}")
|
||||
# poi has, x, y, width, height
|
||||
poi = json_data['poi'][self.poi]
|
||||
self.poi_x = int(poi['x'])
|
||||
self.poi_y = int(poi['y'])
|
||||
self.poi_width = int(poi['width'])
|
||||
self.poi_height = int(poi['height'])
|
||||
# do full image if no poi
|
||||
self.poi_x = 0
|
||||
self.poi_y = 0
|
||||
self.poi_width = self.width
|
||||
self.poi_height = self.height
|
||||
try:
|
||||
if self.poi in json_data['poi']:
|
||||
poi = json_data['poi'][self.poi]
|
||||
self.poi_x = int(poi['x'])
|
||||
self.poi_y = int(poi['y'])
|
||||
self.poi_width = int(poi['width'])
|
||||
self.poi_height = int(poi['height'])
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
# handle flipping
|
||||
if kwargs.get('flip_x', False):
|
||||
|
||||
@@ -1059,7 +1059,7 @@ class StableDiffusion:
|
||||
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
|
||||
# load the refiner model
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
model_path = self.model_config.refiner_name_or_path
|
||||
model_path = self.model_config._original_refiner_name_or_path
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# TODO only load unet??
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user