mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 21:33:59 +00:00
Added random weight adjuster to prevent overfitting
This commit is contained in:
@@ -3,6 +3,8 @@ import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if self.sd.is_xl:
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
guided_target = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
# noise_pred = rescale_noise_cfg(
|
||||
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
||||
# )
|
||||
|
||||
noise_pred = guided_target
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
else:
|
||||
noise_pred = train_util.predict_noise(
|
||||
self.sd.unet,
|
||||
self.sd.noise_scheduler,
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.sd.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
latents,
|
||||
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings,
|
||||
guidance_scale=guidance_scale
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
).sample
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
return noise_pred
|
||||
|
||||
@@ -40,6 +40,7 @@ class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
target_class,
|
||||
target_class_with_neutral,
|
||||
positive_target,
|
||||
positive_target_with_neutral,
|
||||
negative_target,
|
||||
@@ -52,6 +53,7 @@ class EncodedPromptPair:
|
||||
weight=1.0
|
||||
):
|
||||
self.target_class = target_class
|
||||
self.target_class_with_neutral = target_class_with_neutral
|
||||
self.positive_target = positive_target
|
||||
self.positive_target_with_neutral = positive_target_with_neutral
|
||||
self.negative_target = negative_target
|
||||
@@ -171,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
for target in self.slider_config.targets:
|
||||
prompt_list = [
|
||||
f"{target.target_class}", # target_class
|
||||
f"{target.target_class} {neutral}", # target_class with neutral
|
||||
f"{target.positive}", # positive_target
|
||||
f"{target.positive} {neutral}", # positive_target with neutral
|
||||
f"{target.negative}", # negative_target
|
||||
@@ -217,6 +220,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||
positive_target=cache[f"{target.positive}"],
|
||||
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||
negative_target=cache[f"{target.negative}"],
|
||||
@@ -234,6 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||
positive_target=cache[f"{target.negative}"],
|
||||
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||
negative_target=cache[f"{target.positive}"],
|
||||
@@ -251,6 +256,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||
positive_target=cache[f"{target.negative}"],
|
||||
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||
negative_target=cache[f"{target.positive}"],
|
||||
@@ -268,6 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
|
||||
positive_target=cache[f"{target.positive}"],
|
||||
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||
negative_target=cache[f"{target.negative}"],
|
||||
@@ -299,28 +306,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
multiplier=anchor.multiplier
|
||||
)
|
||||
]
|
||||
# self.print("Encoding complete. Building prompt pairs..")
|
||||
# for neutral in self.prompt_txt_list:
|
||||
# for target in self.slider_config.targets:
|
||||
# both_prompts_list = [
|
||||
# f"{target.positive} {target.negative}",
|
||||
# f"{target.negative} {target.positive}",
|
||||
# ]
|
||||
# # randomly pick one of the both prompts to prevent bias
|
||||
# both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()]
|
||||
#
|
||||
# prompt_pair = EncodedPromptPair(
|
||||
# positive_target=cache[f"{target.positive}"],
|
||||
# positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||
# negative_target=cache[f"{target.negative}"],
|
||||
# negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||
# neutral=cache[neutral],
|
||||
# both_targets=cache[both_prompts],
|
||||
# empty_prompt=cache[""],
|
||||
# target_class=cache[f"{target.target_class}"],
|
||||
# weight=target.weight,
|
||||
# ).to(device="cpu", dtype=torch.float32)
|
||||
# self.prompt_pairs.append(prompt_pair)
|
||||
|
||||
|
||||
# move to cpu to save vram
|
||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||
@@ -340,7 +326,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# get random multiplier between 1 and 3
|
||||
rand_weight = torch.rand((1,)).item() * 2 + 1
|
||||
rand_weight = 1
|
||||
# rand_weight = torch.rand((1,)).item() * 2 + 1
|
||||
|
||||
# get a random pair
|
||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||
@@ -367,12 +354,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
lr_scheduler = self.lr_scheduler
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
def get_noise_pred(p, n, gs, cts, dn):
|
||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||
return self.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
p, # negative prompt
|
||||
n, # positive prompt
|
||||
neg, # negative prompt
|
||||
pos, # positive prompt
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
timestep=cts,
|
||||
@@ -410,8 +397,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latents = self.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
@@ -426,15 +413,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
]
|
||||
|
||||
positive_latents = get_noise_pred(
|
||||
positive, negative, 1, current_timestep, denoised_latents
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
prompt_pair.negative_target, # positive prompt
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
neutral_latents = get_noise_pred(
|
||||
positive, neutral, 1, current_timestep, denoised_latents
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
prompt_pair.empty_prompt, # positive prompt (normally neutral)
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
unconditional_latents = get_noise_pred(
|
||||
positive, positive, 1, current_timestep, denoised_latents
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
prompt_pair.positive_target, # positive prompt
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
anchor_loss = None
|
||||
@@ -461,7 +460,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
with self.network:
|
||||
self.network.multiplier = prompt_pair.multiplier * rand_weight
|
||||
target_latents = get_noise_pred(
|
||||
positive, target_class, 1, current_timestep, denoised_latents
|
||||
prompt_pair.positive_target,
|
||||
prompt_pair.target_class,
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
# if self.logging_config.verbose:
|
||||
|
||||
Reference in New Issue
Block a user