Reworked the balancing and swapping of the lora during training to make it much more stable when trained

This commit is contained in:
Jaret Burkett
2023-07-22 14:13:39 -06:00
parent ce8e7a1271
commit 3f4f429c4a

View File

@@ -3,7 +3,7 @@
import time import time
from collections import OrderedDict from collections import OrderedDict
import os import os
from typing import List from typing import List, Literal
from toolkit.kohya_model_util import load_vae from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork from toolkit.lora_special import LoRASpecialNetwork
@@ -25,8 +25,12 @@ from tqdm import tqdm
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
from leco import train_util, model_util from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES from leco.prompt_util import PromptEmbedsCache
from leco import debug_util
class ACTION_TYPES_SLIDER:
ERASE_NEGATIVE = 0
ENHANCE_NEGATIVE = 1
def flush(): def flush():
@@ -104,9 +108,10 @@ class ModelConfig:
class SliderTargetConfig: class SliderTargetConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', None) self.target_class: str = kwargs.get('target_class', '')
self.positive: str = kwargs.get('positive', None) self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None) self.negative: str = kwargs.get('negative', None)
self.multiplier: float = kwargs.get('multiplier', 1.0)
class SliderConfig: class SliderConfig:
@@ -117,20 +122,6 @@ class SliderConfig:
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
class PromptSettingsOld:
def __init__(self, **kwargs):
self.target: str = kwargs.get('target', None)
self.positive = kwargs.get('positive', None) # if None, target will be used
self.unconditional = kwargs.get('unconditional', "") # default is ""
self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used
self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase"
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0
self.resolution: int = kwargs.get('resolution', 512) # default is 512
self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False
self.batch_size: int = kwargs.get('batch_size', 1) # default is 1
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
class EncodedPromptPair: class EncodedPromptPair:
def __init__( def __init__(
self, self,
@@ -139,7 +130,9 @@ class EncodedPromptPair:
negative, negative,
neutral, neutral,
width=512, width=512,
height=512 height=512,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
): ):
self.target_class = target_class self.target_class = target_class
self.positive = positive self.positive = positive
@@ -147,6 +140,8 @@ class EncodedPromptPair:
self.neutral = neutral self.neutral = neutral
self.width = width self.width = width
self.height = height self.height = height
self.action: int = action
self.multiplier = multiplier
class TrainSliderProcess(BaseTrainProcess): class TrainSliderProcess(BaseTrainProcess):
@@ -299,7 +294,7 @@ class TrainSliderProcess(BaseTrainProcess):
def get_training_info(self): def get_training_info(self):
info = OrderedDict({ info = OrderedDict({
'step': self.step_num 'step': self.step_num + 1
}) })
return info return info
@@ -395,7 +390,7 @@ class TrainSliderProcess(BaseTrainProcess):
loss_function = torch.nn.MSELoss() loss_function = torch.nn.MSELoss()
cache = PromptEmbedsCache() cache = PromptEmbedsCache()
prompt_pairs: list[LatentPair] = [] prompt_pairs: list[EncodedPromptPair] = []
# get encoded latents for our prompts # get encoded latents for our prompts
with torch.no_grad(): with torch.no_grad():
@@ -403,6 +398,7 @@ class TrainSliderProcess(BaseTrainProcess):
for target in self.slider_config.targets: for target in self.slider_config.targets:
for resolution in self.slider_config.resolutions: for resolution in self.slider_config.resolutions:
width, height = resolution width, height = resolution
# build the cache
for prompt in [ for prompt in [
target.target_class, target.target_class,
target.positive, target.positive,
@@ -414,7 +410,13 @@ class TrainSliderProcess(BaseTrainProcess):
tokenizer, text_encoder, [prompt] tokenizer, text_encoder, [prompt]
) )
prompt_pairs.append( # for slider we need to have an enhancer, an eraser, and then
# an inverse with negative weights to balance the network
# if we don't do this, we will get different contrast and focus.
# we only perform actions of enhancing and erasing on the negative
# todo work on way to do all of this in one shot
prompt_pairs += [
# erase standard
EncodedPromptPair( EncodedPromptPair(
target_class=cache[target.target_class], target_class=cache[target.target_class],
positive=cache[target.positive], positive=cache[target.positive],
@@ -422,8 +424,43 @@ class TrainSliderProcess(BaseTrainProcess):
neutral=cache[neutral], neutral=cache[neutral],
width=width, width=width,
height=height, height=height,
) action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
) multiplier=target.multiplier
),
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0
),
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier
),
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0
),
]
# move to cpu to save vram # move to cpu to save vram
# tokenizer.to("cpu") # tokenizer.to("cpu")
@@ -449,20 +486,13 @@ class TrainSliderProcess(BaseTrainProcess):
height = prompt_pair.height height = prompt_pair.height
width = prompt_pair.width width = prompt_pair.width
positive = prompt_pair.positive
target_class = prompt_pair.target_class target_class = prompt_pair.target_class
neutral = prompt_pair.neutral neutral = prompt_pair.neutral
negative = prompt_pair.negative negative = prompt_pair.negative
positive = prompt_pair.positive
# swap every other step and invert lora to spread slider # set network multiplier
do_swap = step % 2 == 0 self.network.multiplier = prompt_pair.multiplier
if do_swap:
negative = prompt_pair.positive
positive = prompt_pair.negative
# set the network in a negative weight
self.network.multiplier = -1.0
with torch.no_grad(): with torch.no_grad():
noise_scheduler.set_timesteps( noise_scheduler.set_timesteps(
@@ -492,8 +522,8 @@ class TrainSliderProcess(BaseTrainProcess):
noise_scheduler, noise_scheduler,
latents, # pass simple noise latents latents, # pass simple noise latents
train_util.concat_embeddings( train_util.concat_embeddings(
positive, # unconditional positive, # unconditional
target_class, # target target_class, # target
self.train_config.batch_size, self.train_config.batch_size,
), ),
start_timesteps=0, start_timesteps=0,
@@ -526,7 +556,7 @@ class TrainSliderProcess(BaseTrainProcess):
current_timestep, current_timestep,
denoised_latents, denoised_latents,
train_util.concat_embeddings( train_util.concat_embeddings(
positive, # unconditional positive, # unconditional
neutral, # neutral neutral, # neutral
self.train_config.batch_size, self.train_config.batch_size,
), ),
@@ -553,7 +583,7 @@ class TrainSliderProcess(BaseTrainProcess):
denoised_latents, denoised_latents,
train_util.concat_embeddings( train_util.concat_embeddings(
positive, # unconditional positive, # unconditional
target_class, # target target_class, # target
self.train_config.batch_size, self.train_config.batch_size,
), ),
guidance_scale=1, guidance_scale=1,
@@ -566,7 +596,7 @@ class TrainSliderProcess(BaseTrainProcess):
neutral_latents.requires_grad = False neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False unconditional_latents.requires_grad = False
erase = True erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
guidance_scale = 1.0 guidance_scale = 1.0
offset = guidance_scale * (positive_latents - unconditional_latents) offset = guidance_scale * (positive_latents - unconditional_latents)
@@ -643,11 +673,10 @@ class TrainSliderProcess(BaseTrainProcess):
# end of step # end of step
self.step_num = step self.step_num = step
self.sample(self.step_num)
print("") print("")
self.save() self.save()
del ( del (
unet, unet,
noise_scheduler, noise_scheduler,