mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added random weight adjuster to prevent overfitting
This commit is contained in:
@@ -3,6 +3,8 @@ import time
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||||
|
|
||||||
from toolkit.lora_special import LoRASpecialNetwork
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
from toolkit.optimizer import get_optimizer
|
from toolkit.optimizer import get_optimizer
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
add_time_ids=None,
|
add_time_ids=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
if self.sd.is_xl:
|
if self.sd.is_xl:
|
||||||
if add_time_ids is None:
|
if add_time_ids is None:
|
||||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||||
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
guided_target = noise_pred_uncond + guidance_scale * (
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
noise_pred_text - noise_pred_uncond
|
noise_pred_text - noise_pred_uncond
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||||
# noise_pred = rescale_noise_cfg(
|
if guidance_rescale > 0.0:
|
||||||
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
# )
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
noise_pred = guided_target
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
noise_pred = train_util.predict_noise(
|
# if we are doing classifier free guidance, need to double up
|
||||||
self.sd.unet,
|
latent_model_input = torch.cat([latents] * 2)
|
||||||
self.sd.noise_scheduler,
|
|
||||||
|
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.sd.unet(
|
||||||
|
latent_model_input,
|
||||||
timestep,
|
timestep,
|
||||||
latents,
|
encoder_hidden_states=text_embeddings.text_embeds,
|
||||||
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings,
|
).sample
|
||||||
guidance_scale=guidance_scale
|
|
||||||
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
)
|
)
|
||||||
|
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ class EncodedPromptPair:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target_class,
|
target_class,
|
||||||
|
target_class_with_neutral,
|
||||||
positive_target,
|
positive_target,
|
||||||
positive_target_with_neutral,
|
positive_target_with_neutral,
|
||||||
negative_target,
|
negative_target,
|
||||||
@@ -52,6 +53,7 @@ class EncodedPromptPair:
|
|||||||
weight=1.0
|
weight=1.0
|
||||||
):
|
):
|
||||||
self.target_class = target_class
|
self.target_class = target_class
|
||||||
|
self.target_class_with_neutral = target_class_with_neutral
|
||||||
self.positive_target = positive_target
|
self.positive_target = positive_target
|
||||||
self.positive_target_with_neutral = positive_target_with_neutral
|
self.positive_target_with_neutral = positive_target_with_neutral
|
||||||
self.negative_target = negative_target
|
self.negative_target = negative_target
|
||||||
@@ -171,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
for target in self.slider_config.targets:
|
for target in self.slider_config.targets:
|
||||||
prompt_list = [
|
prompt_list = [
|
||||||
f"{target.target_class}", # target_class
|
f"{target.target_class}", # target_class
|
||||||
|
f"{target.target_class} {neutral}", # target_class with neutral
|
||||||
f"{target.positive}", # positive_target
|
f"{target.positive}", # positive_target
|
||||||
f"{target.positive} {neutral}", # positive_target with neutral
|
f"{target.positive} {neutral}", # positive_target with neutral
|
||||||
f"{target.negative}", # negative_target
|
f"{target.negative}", # negative_target
|
||||||
@@ -217,6 +220,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# erase standard
|
# erase standard
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
target_class=cache[target.target_class],
|
target_class=cache[target.target_class],
|
||||||
|
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||||
positive_target=cache[f"{target.positive}"],
|
positive_target=cache[f"{target.positive}"],
|
||||||
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||||
negative_target=cache[f"{target.negative}"],
|
negative_target=cache[f"{target.negative}"],
|
||||||
@@ -234,6 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# enhance standard, swap pos neg
|
# enhance standard, swap pos neg
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
target_class=cache[target.target_class],
|
target_class=cache[target.target_class],
|
||||||
|
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||||
positive_target=cache[f"{target.negative}"],
|
positive_target=cache[f"{target.negative}"],
|
||||||
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||||
negative_target=cache[f"{target.positive}"],
|
negative_target=cache[f"{target.positive}"],
|
||||||
@@ -251,6 +256,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# erase inverted
|
# erase inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
target_class=cache[target.target_class],
|
target_class=cache[target.target_class],
|
||||||
|
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||||
positive_target=cache[f"{target.negative}"],
|
positive_target=cache[f"{target.negative}"],
|
||||||
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||||
negative_target=cache[f"{target.positive}"],
|
negative_target=cache[f"{target.positive}"],
|
||||||
@@ -268,6 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# enhance inverted
|
# enhance inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
target_class=cache[target.target_class],
|
target_class=cache[target.target_class],
|
||||||
|
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||||
positive_target=cache[f"{target.positive}"],
|
positive_target=cache[f"{target.positive}"],
|
||||||
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||||
negative_target=cache[f"{target.negative}"],
|
negative_target=cache[f"{target.negative}"],
|
||||||
@@ -299,28 +306,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
multiplier=anchor.multiplier
|
multiplier=anchor.multiplier
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
# self.print("Encoding complete. Building prompt pairs..")
|
|
||||||
# for neutral in self.prompt_txt_list:
|
|
||||||
# for target in self.slider_config.targets:
|
|
||||||
# both_prompts_list = [
|
|
||||||
# f"{target.positive} {target.negative}",
|
|
||||||
# f"{target.negative} {target.positive}",
|
|
||||||
# ]
|
|
||||||
# # randomly pick one of the both prompts to prevent bias
|
|
||||||
# both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()]
|
|
||||||
#
|
|
||||||
# prompt_pair = EncodedPromptPair(
|
|
||||||
# positive_target=cache[f"{target.positive}"],
|
|
||||||
# positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
|
||||||
# negative_target=cache[f"{target.negative}"],
|
|
||||||
# negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
|
||||||
# neutral=cache[neutral],
|
|
||||||
# both_targets=cache[both_prompts],
|
|
||||||
# empty_prompt=cache[""],
|
|
||||||
# target_class=cache[f"{target.target_class}"],
|
|
||||||
# weight=target.weight,
|
|
||||||
# ).to(device="cpu", dtype=torch.float32)
|
|
||||||
# self.prompt_pairs.append(prompt_pair)
|
|
||||||
|
|
||||||
# move to cpu to save vram
|
# move to cpu to save vram
|
||||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||||
@@ -340,7 +326,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
# get random multiplier between 1 and 3
|
# get random multiplier between 1 and 3
|
||||||
rand_weight = torch.rand((1,)).item() * 2 + 1
|
rand_weight = 1
|
||||||
|
# rand_weight = torch.rand((1,)).item() * 2 + 1
|
||||||
|
|
||||||
# get a random pair
|
# get a random pair
|
||||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||||
@@ -367,12 +354,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
lr_scheduler = self.lr_scheduler
|
lr_scheduler = self.lr_scheduler
|
||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
def get_noise_pred(p, n, gs, cts, dn):
|
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||||
return self.predict_noise(
|
return self.predict_noise(
|
||||||
latents=dn,
|
latents=dn,
|
||||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||||
p, # negative prompt
|
neg, # negative prompt
|
||||||
n, # positive prompt
|
pos, # positive prompt
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
),
|
),
|
||||||
timestep=cts,
|
timestep=cts,
|
||||||
@@ -410,8 +397,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
denoised_latents = self.diffuse_some_steps(
|
denoised_latents = self.diffuse_some_steps(
|
||||||
latents, # pass simple noise latents
|
latents, # pass simple noise latents
|
||||||
train_tools.concat_prompt_embeddings(
|
train_tools.concat_prompt_embeddings(
|
||||||
positive, # unconditional
|
prompt_pair.positive_target, # unconditional
|
||||||
target_class, # target
|
prompt_pair.target_class, # target
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
),
|
),
|
||||||
start_timesteps=0,
|
start_timesteps=0,
|
||||||
@@ -426,15 +413,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
]
|
]
|
||||||
|
|
||||||
positive_latents = get_noise_pred(
|
positive_latents = get_noise_pred(
|
||||||
positive, negative, 1, current_timestep, denoised_latents
|
prompt_pair.positive_target, # negative prompt
|
||||||
|
prompt_pair.negative_target, # positive prompt
|
||||||
|
1,
|
||||||
|
current_timestep,
|
||||||
|
denoised_latents
|
||||||
).to("cpu", dtype=torch.float32)
|
).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
neutral_latents = get_noise_pred(
|
neutral_latents = get_noise_pred(
|
||||||
positive, neutral, 1, current_timestep, denoised_latents
|
prompt_pair.positive_target, # negative prompt
|
||||||
|
prompt_pair.empty_prompt, # positive prompt (normally neutral
|
||||||
|
1,
|
||||||
|
current_timestep,
|
||||||
|
denoised_latents
|
||||||
).to("cpu", dtype=torch.float32)
|
).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
unconditional_latents = get_noise_pred(
|
unconditional_latents = get_noise_pred(
|
||||||
positive, positive, 1, current_timestep, denoised_latents
|
prompt_pair.positive_target, # negative prompt
|
||||||
|
prompt_pair.positive_target, # positive prompt
|
||||||
|
1,
|
||||||
|
current_timestep,
|
||||||
|
denoised_latents
|
||||||
).to("cpu", dtype=torch.float32)
|
).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
anchor_loss = None
|
anchor_loss = None
|
||||||
@@ -461,7 +460,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
with self.network:
|
with self.network:
|
||||||
self.network.multiplier = prompt_pair.multiplier * rand_weight
|
self.network.multiplier = prompt_pair.multiplier * rand_weight
|
||||||
target_latents = get_noise_pred(
|
target_latents = get_noise_pred(
|
||||||
positive, target_class, 1, current_timestep, denoised_latents
|
prompt_pair.positive_target,
|
||||||
|
prompt_pair.target_class,
|
||||||
|
1,
|
||||||
|
current_timestep,
|
||||||
|
denoised_latents
|
||||||
).to("cpu", dtype=torch.float32)
|
).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
# if self.logging_config.verbose:
|
# if self.logging_config.verbose:
|
||||||
|
|||||||
Reference in New Issue
Block a user