Reworked the balancing and swapping of the LoRA during training to make it much more stable

This commit is contained in:
Jaret Burkett
2023-07-22 14:13:39 -06:00
parent ce8e7a1271
commit 3f4f429c4a

View File

@@ -3,7 +3,7 @@
import time
from collections import OrderedDict
import os
from typing import List
from typing import List, Literal
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
@@ -25,8 +25,12 @@ from tqdm import tqdm
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES
from leco import debug_util
from leco.prompt_util import PromptEmbedsCache
class ACTION_TYPES_SLIDER:
ERASE_NEGATIVE = 0
ENHANCE_NEGATIVE = 1
def flush():
@@ -104,9 +108,10 @@ class ModelConfig:
class SliderTargetConfig:
def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', None)
self.target_class: str = kwargs.get('target_class', '')
self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None)
self.multiplier: float = kwargs.get('multiplier', 1.0)
class SliderConfig:
@@ -117,20 +122,6 @@ class SliderConfig:
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
class PromptSettingsOld:
def __init__(self, **kwargs):
self.target: str = kwargs.get('target', None)
self.positive = kwargs.get('positive', None) # if None, target will be used
self.unconditional = kwargs.get('unconditional', "") # default is ""
self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used
self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase"
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0
self.resolution: int = kwargs.get('resolution', 512) # default is 512
self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False
self.batch_size: int = kwargs.get('batch_size', 1) # default is 1
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
class EncodedPromptPair:
def __init__(
self,
@@ -139,7 +130,9 @@ class EncodedPromptPair:
negative,
neutral,
width=512,
height=512
height=512,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
):
self.target_class = target_class
self.positive = positive
@@ -147,6 +140,8 @@ class EncodedPromptPair:
self.neutral = neutral
self.width = width
self.height = height
self.action: int = action
self.multiplier = multiplier
class TrainSliderProcess(BaseTrainProcess):
@@ -299,7 +294,7 @@ class TrainSliderProcess(BaseTrainProcess):
def get_training_info(self):
info = OrderedDict({
'step': self.step_num
'step': self.step_num + 1
})
return info
@@ -395,7 +390,7 @@ class TrainSliderProcess(BaseTrainProcess):
loss_function = torch.nn.MSELoss()
cache = PromptEmbedsCache()
prompt_pairs: list[LatentPair] = []
prompt_pairs: list[EncodedPromptPair] = []
# get encoded latents for our prompts
with torch.no_grad():
@@ -403,6 +398,7 @@ class TrainSliderProcess(BaseTrainProcess):
for target in self.slider_config.targets:
for resolution in self.slider_config.resolutions:
width, height = resolution
# build the cache
for prompt in [
target.target_class,
target.positive,
@@ -414,7 +410,13 @@ class TrainSliderProcess(BaseTrainProcess):
tokenizer, text_encoder, [prompt]
)
prompt_pairs.append(
# for a slider we need an enhancer, an eraser, and then an inverse
# of each with negative weights (multiplier * -1.0) to balance the network.
# if we don't do this, we will get different contrast and focus.
# we only perform the actions of enhancing and erasing on the negative.
# todo: work on a way to do all of this in one shot
prompt_pairs += [
# erase standard
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
@@ -422,8 +424,43 @@ class TrainSliderProcess(BaseTrainProcess):
neutral=cache[neutral],
width=width,
height=height,
)
)
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier
),
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0
),
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier
),
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0
),
]
# move to cpu to save vram
# tokenizer.to("cpu")
@@ -449,20 +486,13 @@ class TrainSliderProcess(BaseTrainProcess):
height = prompt_pair.height
width = prompt_pair.width
positive = prompt_pair.positive
target_class = prompt_pair.target_class
neutral = prompt_pair.neutral
negative = prompt_pair.negative
positive = prompt_pair.positive
# swap every other step and invert lora to spread slider
do_swap = step % 2 == 0
if do_swap:
negative = prompt_pair.positive
positive = prompt_pair.negative
# set the network in a negative weight
self.network.multiplier = -1.0
# set network multiplier
self.network.multiplier = prompt_pair.multiplier
with torch.no_grad():
noise_scheduler.set_timesteps(
@@ -492,8 +522,8 @@ class TrainSliderProcess(BaseTrainProcess):
noise_scheduler,
latents, # pass simple noise latents
train_util.concat_embeddings(
positive, # unconditional
target_class, # target
positive, # unconditional
target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
@@ -526,7 +556,7 @@ class TrainSliderProcess(BaseTrainProcess):
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
positive, # unconditional
neutral, # neutral
self.train_config.batch_size,
),
@@ -553,7 +583,7 @@ class TrainSliderProcess(BaseTrainProcess):
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
target_class, # target
target_class, # target
self.train_config.batch_size,
),
guidance_scale=1,
@@ -566,7 +596,7 @@ class TrainSliderProcess(BaseTrainProcess):
neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False
erase = True
erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
guidance_scale = 1.0
offset = guidance_scale * (positive_latents - unconditional_latents)
@@ -643,11 +673,10 @@ class TrainSliderProcess(BaseTrainProcess):
# end of step
self.step_num = step
self.sample(self.step_num)
print("")
self.save()
del (
unet,
noise_scheduler,