mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Reworked the balancing and swapping of the LoRA during training to make it much more stable
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Literal
|
||||
|
||||
from toolkit.kohya_model_util import load_vae
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
@@ -25,8 +25,12 @@ from tqdm import tqdm
|
||||
|
||||
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
|
||||
from leco import train_util, model_util
|
||||
from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES
|
||||
from leco import debug_util
|
||||
from leco.prompt_util import PromptEmbedsCache
|
||||
|
||||
|
||||
class ACTION_TYPES_SLIDER:
|
||||
ERASE_NEGATIVE = 0
|
||||
ENHANCE_NEGATIVE = 1
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -104,9 +108,10 @@ class ModelConfig:
|
||||
|
||||
class SliderTargetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.target_class: str = kwargs.get('target_class', None)
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.positive: str = kwargs.get('positive', None)
|
||||
self.negative: str = kwargs.get('negative', None)
|
||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||
|
||||
|
||||
class SliderConfig:
|
||||
@@ -117,20 +122,6 @@ class SliderConfig:
|
||||
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
|
||||
|
||||
|
||||
class PromptSettingsOld:
|
||||
def __init__(self, **kwargs):
|
||||
self.target: str = kwargs.get('target', None)
|
||||
self.positive = kwargs.get('positive', None) # if None, target will be used
|
||||
self.unconditional = kwargs.get('unconditional', "") # default is ""
|
||||
self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used
|
||||
self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase"
|
||||
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0
|
||||
self.resolution: int = kwargs.get('resolution', 512) # default is 512
|
||||
self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False
|
||||
self.batch_size: int = kwargs.get('batch_size', 1) # default is 1
|
||||
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -139,7 +130,9 @@ class EncodedPromptPair:
|
||||
negative,
|
||||
neutral,
|
||||
width=512,
|
||||
height=512
|
||||
height=512,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=1.0,
|
||||
):
|
||||
self.target_class = target_class
|
||||
self.positive = positive
|
||||
@@ -147,6 +140,8 @@ class EncodedPromptPair:
|
||||
self.neutral = neutral
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.action: int = action
|
||||
self.multiplier = multiplier
|
||||
|
||||
|
||||
class TrainSliderProcess(BaseTrainProcess):
|
||||
@@ -299,7 +294,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
|
||||
def get_training_info(self):
|
||||
info = OrderedDict({
|
||||
'step': self.step_num
|
||||
'step': self.step_num + 1
|
||||
})
|
||||
return info
|
||||
|
||||
@@ -395,7 +390,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
cache = PromptEmbedsCache()
|
||||
prompt_pairs: list[LatentPair] = []
|
||||
prompt_pairs: list[EncodedPromptPair] = []
|
||||
|
||||
# get encoded latents for our prompts
|
||||
with torch.no_grad():
|
||||
@@ -403,6 +398,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
for target in self.slider_config.targets:
|
||||
for resolution in self.slider_config.resolutions:
|
||||
width, height = resolution
|
||||
# build the cache
|
||||
for prompt in [
|
||||
target.target_class,
|
||||
target.positive,
|
||||
@@ -414,7 +410,13 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
tokenizer, text_encoder, [prompt]
|
||||
)
|
||||
|
||||
prompt_pairs.append(
|
||||
# for slider we need to have an enhancer, an eraser, and then
|
||||
# an inverse with negative weights to balance the network
|
||||
# if we don't do this, we will get different contrast and focus.
|
||||
# we only perform actions of enhancing and erasing on the negative
|
||||
# todo work on way to do all of this in one shot
|
||||
prompt_pairs += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
@@ -422,8 +424,43 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
)
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier
|
||||
),
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0
|
||||
),
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier
|
||||
),
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
negative=cache[target.negative],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0
|
||||
),
|
||||
]
|
||||
|
||||
# move to cpu to save vram
|
||||
# tokenizer.to("cpu")
|
||||
@@ -449,20 +486,13 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
|
||||
height = prompt_pair.height
|
||||
width = prompt_pair.width
|
||||
positive = prompt_pair.positive
|
||||
target_class = prompt_pair.target_class
|
||||
neutral = prompt_pair.neutral
|
||||
negative = prompt_pair.negative
|
||||
positive = prompt_pair.positive
|
||||
|
||||
# swap every other step and invert lora to spread slider
|
||||
do_swap = step % 2 == 0
|
||||
|
||||
if do_swap:
|
||||
negative = prompt_pair.positive
|
||||
positive = prompt_pair.negative
|
||||
# set the network in a negative weight
|
||||
self.network.multiplier = -1.0
|
||||
|
||||
# set network multiplier
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
|
||||
with torch.no_grad():
|
||||
noise_scheduler.set_timesteps(
|
||||
@@ -492,8 +522,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
noise_scheduler,
|
||||
latents, # pass simple noise latents
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
@@ -526,7 +556,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
positive, # unconditional
|
||||
neutral, # neutral
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
@@ -553,7 +583,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
@@ -566,7 +596,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
neutral_latents.requires_grad = False
|
||||
unconditional_latents.requires_grad = False
|
||||
|
||||
erase = True
|
||||
erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
|
||||
guidance_scale = 1.0
|
||||
|
||||
offset = guidance_scale * (positive_latents - unconditional_latents)
|
||||
@@ -643,11 +673,10 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
# end of step
|
||||
self.step_num = step
|
||||
|
||||
self.sample(self.step_num)
|
||||
print("")
|
||||
|
||||
self.save()
|
||||
|
||||
|
||||
del (
|
||||
unet,
|
||||
noise_scheduler,
|
||||
|
||||
Reference in New Issue
Block a user