mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Reworked the balancing and swapping of the lora during training to make it much more stable when trained
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Literal
|
||||||
|
|
||||||
from toolkit.kohya_model_util import load_vae
|
from toolkit.kohya_model_util import load_vae
|
||||||
from toolkit.lora_special import LoRASpecialNetwork
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
@@ -25,8 +25,12 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
|
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
|
||||||
from leco import train_util, model_util
|
from leco import train_util, model_util
|
||||||
from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES
|
from leco.prompt_util import PromptEmbedsCache
|
||||||
from leco import debug_util
|
|
||||||
|
|
||||||
|
class ACTION_TYPES_SLIDER:
|
||||||
|
ERASE_NEGATIVE = 0
|
||||||
|
ENHANCE_NEGATIVE = 1
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -104,9 +108,10 @@ class ModelConfig:
|
|||||||
|
|
||||||
class SliderTargetConfig:
|
class SliderTargetConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.target_class: str = kwargs.get('target_class', None)
|
self.target_class: str = kwargs.get('target_class', '')
|
||||||
self.positive: str = kwargs.get('positive', None)
|
self.positive: str = kwargs.get('positive', None)
|
||||||
self.negative: str = kwargs.get('negative', None)
|
self.negative: str = kwargs.get('negative', None)
|
||||||
|
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||||
|
|
||||||
|
|
||||||
class SliderConfig:
|
class SliderConfig:
|
||||||
@@ -117,20 +122,6 @@ class SliderConfig:
|
|||||||
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
|
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
|
||||||
|
|
||||||
|
|
||||||
class PromptSettingsOld:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.target: str = kwargs.get('target', None)
|
|
||||||
self.positive = kwargs.get('positive', None) # if None, target will be used
|
|
||||||
self.unconditional = kwargs.get('unconditional', "") # default is ""
|
|
||||||
self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used
|
|
||||||
self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase"
|
|
||||||
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0
|
|
||||||
self.resolution: int = kwargs.get('resolution', 512) # default is 512
|
|
||||||
self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False
|
|
||||||
self.batch_size: int = kwargs.get('batch_size', 1) # default is 1
|
|
||||||
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
|
|
||||||
|
|
||||||
|
|
||||||
class EncodedPromptPair:
|
class EncodedPromptPair:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -139,7 +130,9 @@ class EncodedPromptPair:
|
|||||||
negative,
|
negative,
|
||||||
neutral,
|
neutral,
|
||||||
width=512,
|
width=512,
|
||||||
height=512
|
height=512,
|
||||||
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
|
multiplier=1.0,
|
||||||
):
|
):
|
||||||
self.target_class = target_class
|
self.target_class = target_class
|
||||||
self.positive = positive
|
self.positive = positive
|
||||||
@@ -147,6 +140,8 @@ class EncodedPromptPair:
|
|||||||
self.neutral = neutral
|
self.neutral = neutral
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
|
self.action: int = action
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
|
|
||||||
class TrainSliderProcess(BaseTrainProcess):
|
class TrainSliderProcess(BaseTrainProcess):
|
||||||
@@ -299,7 +294,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
def get_training_info(self):
|
def get_training_info(self):
|
||||||
info = OrderedDict({
|
info = OrderedDict({
|
||||||
'step': self.step_num
|
'step': self.step_num + 1
|
||||||
})
|
})
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@@ -395,7 +390,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
cache = PromptEmbedsCache()
|
cache = PromptEmbedsCache()
|
||||||
prompt_pairs: list[LatentPair] = []
|
prompt_pairs: list[EncodedPromptPair] = []
|
||||||
|
|
||||||
# get encoded latents for our prompts
|
# get encoded latents for our prompts
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -403,6 +398,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
for target in self.slider_config.targets:
|
for target in self.slider_config.targets:
|
||||||
for resolution in self.slider_config.resolutions:
|
for resolution in self.slider_config.resolutions:
|
||||||
width, height = resolution
|
width, height = resolution
|
||||||
|
# build the cache
|
||||||
for prompt in [
|
for prompt in [
|
||||||
target.target_class,
|
target.target_class,
|
||||||
target.positive,
|
target.positive,
|
||||||
@@ -414,7 +410,13 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
tokenizer, text_encoder, [prompt]
|
tokenizer, text_encoder, [prompt]
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_pairs.append(
|
# for slider we need to have an enhancer, an eraser, and then
|
||||||
|
# an inverse with negative weights to balance the network
|
||||||
|
# if we don't do this, we will get different contrast and focus.
|
||||||
|
# we only perform actions of enhancing and erasing on the negative
|
||||||
|
# todo work on way to do all of this in one shot
|
||||||
|
prompt_pairs += [
|
||||||
|
# erase standard
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
target_class=cache[target.target_class],
|
target_class=cache[target.target_class],
|
||||||
positive=cache[target.positive],
|
positive=cache[target.positive],
|
||||||
@@ -422,8 +424,43 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
neutral=cache[neutral],
|
neutral=cache[neutral],
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
)
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
)
|
multiplier=target.multiplier
|
||||||
|
),
|
||||||
|
# erase inverted
|
||||||
|
EncodedPromptPair(
|
||||||
|
target_class=cache[target.target_class],
|
||||||
|
positive=cache[target.negative],
|
||||||
|
negative=cache[target.positive],
|
||||||
|
neutral=cache[neutral],
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
|
multiplier=target.multiplier * -1.0
|
||||||
|
),
|
||||||
|
# enhance standard, swap pos neg
|
||||||
|
EncodedPromptPair(
|
||||||
|
target_class=cache[target.target_class],
|
||||||
|
positive=cache[target.negative],
|
||||||
|
negative=cache[target.positive],
|
||||||
|
neutral=cache[neutral],
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||||
|
multiplier=target.multiplier
|
||||||
|
),
|
||||||
|
# enhance inverted
|
||||||
|
EncodedPromptPair(
|
||||||
|
target_class=cache[target.target_class],
|
||||||
|
positive=cache[target.positive],
|
||||||
|
negative=cache[target.negative],
|
||||||
|
neutral=cache[neutral],
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||||
|
multiplier=target.multiplier * -1.0
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
# move to cpu to save vram
|
# move to cpu to save vram
|
||||||
# tokenizer.to("cpu")
|
# tokenizer.to("cpu")
|
||||||
@@ -449,20 +486,13 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
height = prompt_pair.height
|
height = prompt_pair.height
|
||||||
width = prompt_pair.width
|
width = prompt_pair.width
|
||||||
positive = prompt_pair.positive
|
|
||||||
target_class = prompt_pair.target_class
|
target_class = prompt_pair.target_class
|
||||||
neutral = prompt_pair.neutral
|
neutral = prompt_pair.neutral
|
||||||
negative = prompt_pair.negative
|
negative = prompt_pair.negative
|
||||||
|
positive = prompt_pair.positive
|
||||||
|
|
||||||
# swap every other step and invert lora to spread slider
|
# set network multiplier
|
||||||
do_swap = step % 2 == 0
|
self.network.multiplier = prompt_pair.multiplier
|
||||||
|
|
||||||
if do_swap:
|
|
||||||
negative = prompt_pair.positive
|
|
||||||
positive = prompt_pair.negative
|
|
||||||
# set the network in a negative weight
|
|
||||||
self.network.multiplier = -1.0
|
|
||||||
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
noise_scheduler.set_timesteps(
|
noise_scheduler.set_timesteps(
|
||||||
@@ -492,8 +522,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
noise_scheduler,
|
noise_scheduler,
|
||||||
latents, # pass simple noise latents
|
latents, # pass simple noise latents
|
||||||
train_util.concat_embeddings(
|
train_util.concat_embeddings(
|
||||||
positive, # unconditional
|
positive, # unconditional
|
||||||
target_class, # target
|
target_class, # target
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
),
|
),
|
||||||
start_timesteps=0,
|
start_timesteps=0,
|
||||||
@@ -526,7 +556,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
current_timestep,
|
current_timestep,
|
||||||
denoised_latents,
|
denoised_latents,
|
||||||
train_util.concat_embeddings(
|
train_util.concat_embeddings(
|
||||||
positive, # unconditional
|
positive, # unconditional
|
||||||
neutral, # neutral
|
neutral, # neutral
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
),
|
),
|
||||||
@@ -553,7 +583,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
denoised_latents,
|
denoised_latents,
|
||||||
train_util.concat_embeddings(
|
train_util.concat_embeddings(
|
||||||
positive, # unconditional
|
positive, # unconditional
|
||||||
target_class, # target
|
target_class, # target
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
),
|
),
|
||||||
guidance_scale=1,
|
guidance_scale=1,
|
||||||
@@ -566,7 +596,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
neutral_latents.requires_grad = False
|
neutral_latents.requires_grad = False
|
||||||
unconditional_latents.requires_grad = False
|
unconditional_latents.requires_grad = False
|
||||||
|
|
||||||
erase = True
|
erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
|
||||||
guidance_scale = 1.0
|
guidance_scale = 1.0
|
||||||
|
|
||||||
offset = guidance_scale * (positive_latents - unconditional_latents)
|
offset = guidance_scale * (positive_latents - unconditional_latents)
|
||||||
@@ -643,11 +673,10 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
# end of step
|
# end of step
|
||||||
self.step_num = step
|
self.step_num = step
|
||||||
|
|
||||||
|
self.sample(self.step_num)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
|
|
||||||
del (
|
del (
|
||||||
unet,
|
unet,
|
||||||
noise_scheduler,
|
noise_scheduler,
|
||||||
|
|||||||
Reference in New Issue
Block a user