From 3f4f429c4abb99e6ec14dadf226150dead502b05 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 22 Jul 2023 14:13:39 -0600
Subject: [PATCH] Reworked the balancing and swapping of the LoRA during
 training to make it much more stable

---
 jobs/process/TrainSliderProcess.py | 111 ++++++++++++++++++-----------
 1 file changed, 70 insertions(+), 41 deletions(-)

diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index 1fa6d13c..bc619556 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -3,7 +3,7 @@
 import time
 from collections import OrderedDict
 import os
-from typing import List
+from typing import List, Literal
 
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
@@ -25,8 +25,12 @@
 from tqdm import tqdm
 
 from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
 from leco import train_util, model_util
-from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES
-from leco import debug_util
+from leco.prompt_util import PromptEmbedsCache
+
+
+class ACTION_TYPES_SLIDER:
+    ERASE_NEGATIVE = 0
+    ENHANCE_NEGATIVE = 1
 
 
 def flush():
@@ -104,9 +108,10 @@ class ModelConfig:
 
 class SliderTargetConfig:
     def __init__(self, **kwargs):
-        self.target_class: str = kwargs.get('target_class', None)
+        self.target_class: str = kwargs.get('target_class', '')
         self.positive: str = kwargs.get('positive', None)
         self.negative: str = kwargs.get('negative', None)
+        self.multiplier: float = kwargs.get('multiplier', 1.0)
 
 
 class SliderConfig:
@@ -117,20 +122,6 @@ class SliderConfig:
         self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
 
 
-class PromptSettingsOld:
-    def __init__(self, **kwargs):
-        self.target: str = kwargs.get('target', None)
-        self.positive = kwargs.get('positive', None)  # if None, target will be used
-        self.unconditional = kwargs.get('unconditional', "")  # default is ""
-        self.neutral = kwargs.get('neutral', None)  # if None, unconditional will be used
-        self.action: ACTION_TYPES = kwargs.get('action', "erase")  # default is "erase"
-        self.guidance_scale: float = kwargs.get('guidance_scale', 1.0)  # default is 1.0
-        self.resolution: int = kwargs.get('resolution', 512)  # default is 512
-        self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False)  # default is False
-        self.batch_size: int = kwargs.get('batch_size', 1)  # default is 1
-        self.dynamic_crops: bool = kwargs.get('dynamic_crops', False)  # default is False. only used when model is XL
-
-
 class EncodedPromptPair:
     def __init__(
             self,
@@ -139,7 +130,9 @@ class EncodedPromptPair:
             negative,
             neutral,
             width=512,
-            height=512
+            height=512,
+            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+            multiplier=1.0,
     ):
         self.target_class = target_class
         self.positive = positive
@@ -147,6 +140,8 @@ class EncodedPromptPair:
         self.neutral = neutral
         self.width = width
         self.height = height
+        self.action: int = action
+        self.multiplier = multiplier
 
 
 class TrainSliderProcess(BaseTrainProcess):
@@ -299,7 +294,7 @@ class TrainSliderProcess(BaseTrainProcess):
 
     def get_training_info(self):
         info = OrderedDict({
-            'step': self.step_num
+            'step': self.step_num + 1
         })
         return info
@@ -395,7 +390,7 @@ class TrainSliderProcess(BaseTrainProcess):
         loss_function = torch.nn.MSELoss()
 
         cache = PromptEmbedsCache()
-        prompt_pairs: list[LatentPair] = []
+        prompt_pairs: list[EncodedPromptPair] = []
 
         # get encoded latents for our prompts
         with torch.no_grad():
@@ -403,6 +398,7 @@ class TrainSliderProcess(BaseTrainProcess):
             for target in self.slider_config.targets:
                 for resolution in self.slider_config.resolutions:
                     width, height = resolution
+                    # build the cache
                     for prompt in [
                         target.target_class,
                         target.positive,
@@ -414,7 +410,13 @@ class TrainSliderProcess(BaseTrainProcess):
                             tokenizer, text_encoder, [prompt]
                        )
 
-                    prompt_pairs.append(
+                    # For a slider we need an enhancer, an eraser, and then an
+                    # inverse of each with negative weights to balance the network.
+                    # Without this, the two directions drift apart in contrast and focus.
+                    # We only perform the enhance and erase actions on the negative.
+                    # TODO: find a way to do all of this in one shot
+                    prompt_pairs += [
+                        # erase standard
                         EncodedPromptPair(
                             target_class=cache[target.target_class],
                             positive=cache[target.positive],
@@ -422,8 +424,43 @@ class TrainSliderProcess(BaseTrainProcess):
                             neutral=cache[neutral],
                             width=width,
                             height=height,
-                        )
-                    )
+                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+                            multiplier=target.multiplier
+                        ),
+                        # erase inverted
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.negative],
+                            negative=cache[target.positive],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+                            multiplier=target.multiplier * -1.0
+                        ),
+                        # enhance standard, swap positive and negative
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.negative],
+                            negative=cache[target.positive],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+                            multiplier=target.multiplier
+                        ),
+                        # enhance inverted
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.positive],
+                            negative=cache[target.negative],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+                            multiplier=target.multiplier * -1.0
+                        ),
+                    ]
 
             # move to cpu to save vram
             # tokenizer.to("cpu")
@@ -449,20 +486,13 @@ class TrainSliderProcess(BaseTrainProcess):
                 height = prompt_pair.height
                 width = prompt_pair.width
-                positive = prompt_pair.positive
                 target_class = prompt_pair.target_class
                 neutral = prompt_pair.neutral
                 negative = prompt_pair.negative
+                positive = prompt_pair.positive
 
-                # swap every other step and invert lora to spread slider
-                do_swap = step % 2 == 0
-
-                if do_swap:
-                    negative = prompt_pair.positive
-                    positive = prompt_pair.negative
-                    # set the network in a negative weight
-                    self.network.multiplier = -1.0
-
+                # set the network multiplier for this pair
+                self.network.multiplier = prompt_pair.multiplier
 
                 with torch.no_grad():
                     noise_scheduler.set_timesteps(
@@ -492,8 +522,8 @@ class TrainSliderProcess(BaseTrainProcess):
                         noise_scheduler,
                         latents,  # pass simple noise latents
                         train_util.concat_embeddings(
-                            positive, # unconditional
-                            target_class, # target
+                            positive,  # unconditional
+                            target_class,  # target
                             self.train_config.batch_size,
                         ),
                         start_timesteps=0,
@@ -526,7 +556,7 @@ class TrainSliderProcess(BaseTrainProcess):
                     current_timestep,
                     denoised_latents,
                     train_util.concat_embeddings(
-                        positive, # unconditional
+                        positive,  # unconditional
                         neutral, # neutral
                         self.train_config.batch_size,
                     ),
@@ -553,7 +583,7 @@ class TrainSliderProcess(BaseTrainProcess):
                     denoised_latents,
                     train_util.concat_embeddings(
                         positive, # unconditional
-                        target_class, # target
+                        target_class,  # target
                         self.train_config.batch_size,
                     ),
                     guidance_scale=1,
@@ -566,7 +596,7 @@ class TrainSliderProcess(BaseTrainProcess):
             neutral_latents.requires_grad = False
             unconditional_latents.requires_grad = False
 
-            erase = True
+            erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
             guidance_scale = 1.0
 
             offset = guidance_scale * (positive_latents - unconditional_latents)
@@ -643,11 +673,10 @@ class TrainSliderProcess(BaseTrainProcess):
 
                 # end of step
                 self.step_num = step
 
+        self.sample(self.step_num)
         print("")
 
-        self.save()
-
         del (
             unet,
             noise_scheduler,
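
The core of the change is the prompt-pair construction: rather than swapping positive and negative prompts every other step and forcing the network weight to -1.0, each target now expands into four (action, multiplier) combinations, and the training loop applies the pair's own multiplier. A minimal standalone sketch of that expansion, with stand-in names in place of the toolkit's real classes:

    from dataclasses import dataclass

    ERASE_NEGATIVE = 0
    ENHANCE_NEGATIVE = 1

    @dataclass
    class Pair:
        positive: str      # prompt treated as positive this step
        negative: str      # prompt treated as negative this step
        action: int        # ERASE_NEGATIVE or ENHANCE_NEGATIVE
        multiplier: float  # LoRA network weight applied for this pair

    def build_pairs(positive: str, negative: str, m: float = 1.0) -> list[Pair]:
        return [
            Pair(positive, negative, ERASE_NEGATIVE, m),     # erase standard
            Pair(negative, positive, ERASE_NEGATIVE, -m),    # erase inverted
            Pair(negative, positive, ENHANCE_NEGATIVE, m),   # enhance standard
            Pair(positive, negative, ENHANCE_NEGATIVE, -m),  # enhance inverted
        ]

Every step then does network.multiplier = pair.multiplier before predicting noise, which is what replaces the old do_swap = step % 2 == 0 logic and keeps the slider trained symmetrically in both weight directions.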
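The erase flag derived from prompt_pair.action picks the training direction. The hunks above stop short of the loss wiring, but in the usual LECO-style formulation the target is the neutral prediction shifted by the guidance offset; a sketch under that assumption (the function name is hypothetical, tensor names mirror the variables in the diff):

    import torch

    def make_target(neutral_latents: torch.Tensor,
                    positive_latents: torch.Tensor,
                    unconditional_latents: torch.Tensor,
                    erase: bool,
                    guidance_scale: float = 1.0) -> torch.Tensor:
        # offset points toward the concept being controlled
        offset = guidance_scale * (positive_latents - unconditional_latents)
        # erasing pushes the prediction away from the concept;
        # enhancing pulls it toward the concept
        return neutral_latents - offset if erase else neutral_latents + offset

The MSELoss instantiated earlier would then be taken between the network's prediction and this target.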
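Each SliderTargetConfig also gains a per-target multiplier (default 1.0); since the inverted pairs negate it, one value scales the slider's strength in both directions. A hypothetical usage fragment (the prompt strings here are invented for illustration):

    # hypothetical example values; keys mirror SliderTargetConfig(**kwargs)
    target = SliderTargetConfig(
        target_class='person',
        positive='smiling person',
        negative='frowning person',
        multiplier=1.0,
    )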