Added random weight adjuster to prevent overfitting

This commit is contained in:
Jaret Burkett
2023-07-29 19:30:14 -06:00
parent c35b78f0d4
commit c01673f1b5
2 changed files with 56 additions and 44 deletions

View File

@@ -3,6 +3,8 @@ import time
from collections import OrderedDict
import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
add_time_ids=None,
**kwargs,
):
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided_target = noise_pred_uncond + guidance_scale * (
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
# noise_pred = rescale_noise_cfg(
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
# )
noise_pred = guided_target
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
else:
noise_pred = train_util.predict_noise(
self.sd.unet,
self.sd.noise_scheduler,
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep,
latents,
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings,
guidance_scale=guidance_scale
encoder_hidden_states=text_embeddings.text_embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred

View File

@@ -40,6 +40,7 @@ class EncodedPromptPair:
def __init__(
self,
target_class,
target_class_with_neutral,
positive_target,
positive_target_with_neutral,
negative_target,
@@ -52,6 +53,7 @@ class EncodedPromptPair:
weight=1.0
):
self.target_class = target_class
self.target_class_with_neutral = target_class_with_neutral
self.positive_target = positive_target
self.positive_target_with_neutral = positive_target_with_neutral
self.negative_target = negative_target
@@ -171,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
for target in self.slider_config.targets:
prompt_list = [
f"{target.target_class}", # target_class
f"{target.target_class} {neutral}", # target_class with neutral
f"{target.positive}", # positive_target
f"{target.positive} {neutral}", # positive_target with neutral
f"{target.negative}", # negative_target
@@ -217,6 +220,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# erase standard
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"],
@@ -234,6 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"],
@@ -251,6 +256,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"],
@@ -268,6 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"],
@@ -299,28 +306,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
multiplier=anchor.multiplier
)
]
# self.print("Encoding complete. Building prompt pairs..")
# for neutral in self.prompt_txt_list:
# for target in self.slider_config.targets:
# both_prompts_list = [
# f"{target.positive} {target.negative}",
# f"{target.negative} {target.positive}",
# ]
# # randomly pick one of the both prompts to prevent bias
# both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()]
#
# prompt_pair = EncodedPromptPair(
# positive_target=cache[f"{target.positive}"],
# positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
# negative_target=cache[f"{target.negative}"],
# negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
# neutral=cache[neutral],
# both_targets=cache[both_prompts],
# empty_prompt=cache[""],
# target_class=cache[f"{target.target_class}"],
# weight=target.weight,
# ).to(device="cpu", dtype=torch.float32)
# self.prompt_pairs.append(prompt_pair)
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
@@ -340,7 +326,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
# get random multiplier between 1 and 3
rand_weight = torch.rand((1,)).item() * 2 + 1
rand_weight = 1
# rand_weight = torch.rand((1,)).item() * 2 + 1
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
@@ -367,12 +354,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
def get_noise_pred(p, n, gs, cts, dn):
def get_noise_pred(neg, pos, gs, cts, dn):
return self.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
p, # negative prompt
n, # positive prompt
neg, # negative prompt
pos, # positive prompt
self.train_config.batch_size,
),
timestep=cts,
@@ -410,8 +397,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
positive, # unconditional
target_class, # target
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
@@ -426,15 +413,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
]
positive_latents = get_noise_pred(
positive, negative, 1, current_timestep, denoised_latents
prompt_pair.positive_target, # negative prompt
prompt_pair.negative_target, # positive prompt
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32)
neutral_latents = get_noise_pred(
positive, neutral, 1, current_timestep, denoised_latents
prompt_pair.positive_target, # negative prompt
prompt_pair.empty_prompt, # positive prompt (normally neutral)
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32)
unconditional_latents = get_noise_pred(
positive, positive, 1, current_timestep, denoised_latents
prompt_pair.positive_target, # negative prompt
prompt_pair.positive_target, # positive prompt
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32)
anchor_loss = None
@@ -461,7 +460,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
self.network.multiplier = prompt_pair.multiplier * rand_weight
target_latents = get_noise_pred(
positive, target_class, 1, current_timestep, denoised_latents
prompt_pair.positive_target,
prompt_pair.target_class,
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32)
# if self.logging_config.verbose: