Added random weight adjuster to prevent overfitting

This commit is contained in:
Jaret Burkett
2023-07-29 19:30:14 -06:00
parent c35b78f0d4
commit c01673f1b5
2 changed files with 56 additions and 44 deletions

View File

@@ -3,6 +3,8 @@ import time
from collections import OrderedDict from collections import OrderedDict
import os import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from toolkit.lora_special import LoRASpecialNetwork from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT from toolkit.paths import REPOS_ROOT
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
add_time_ids=None, add_time_ids=None,
**kwargs, **kwargs,
): ):
if self.sd.is_xl: if self.sd.is_xl:
if add_time_ids is None: if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents) add_time_ids = self.get_time_ids_from_latents(latents)
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
# perform guidance # perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided_target = noise_pred_uncond + guidance_scale * ( noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond noise_pred_text - noise_pred_uncond
) )
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
# noise_pred = rescale_noise_cfg( if guidance_rescale > 0.0:
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# ) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
noise_pred = guided_target
else: else:
noise_pred = train_util.predict_noise( # if we are doing classifier free guidance, need to double up
self.sd.unet, latent_model_input = torch.cat([latents] * 2)
self.sd.noise_scheduler,
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep, timestep,
latents, encoder_hidden_states=text_embeddings.text_embeds,
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings, ).sample
guidance_scale=guidance_scale
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
) )
return noise_pred return noise_pred

View File

@@ -40,6 +40,7 @@ class EncodedPromptPair:
def __init__( def __init__(
self, self,
target_class, target_class,
target_class_with_neutral,
positive_target, positive_target,
positive_target_with_neutral, positive_target_with_neutral,
negative_target, negative_target,
@@ -52,6 +53,7 @@ class EncodedPromptPair:
weight=1.0 weight=1.0
): ):
self.target_class = target_class self.target_class = target_class
self.target_class_with_neutral = target_class_with_neutral
self.positive_target = positive_target self.positive_target = positive_target
self.positive_target_with_neutral = positive_target_with_neutral self.positive_target_with_neutral = positive_target_with_neutral
self.negative_target = negative_target self.negative_target = negative_target
@@ -171,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
for target in self.slider_config.targets: for target in self.slider_config.targets:
prompt_list = [ prompt_list = [
f"{target.target_class}", # target_class f"{target.target_class}", # target_class
f"{target.target_class} {neutral}", # target_class with neutral
f"{target.positive}", # positive_target f"{target.positive}", # positive_target
f"{target.positive} {neutral}", # positive_target with neutral f"{target.positive} {neutral}", # positive_target with neutral
f"{target.negative}", # negative_target f"{target.negative}", # negative_target
@@ -217,6 +220,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# erase standard # erase standard
EncodedPromptPair( EncodedPromptPair(
target_class=cache[target.target_class], target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"], positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"], positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"], negative_target=cache[f"{target.negative}"],
@@ -234,6 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# enhance standard, swap pos neg # enhance standard, swap pos neg
EncodedPromptPair( EncodedPromptPair(
target_class=cache[target.target_class], target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"], positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"], positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"], negative_target=cache[f"{target.positive}"],
@@ -251,6 +256,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# erase inverted # erase inverted
EncodedPromptPair( EncodedPromptPair(
target_class=cache[target.target_class], target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"], positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"], positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"], negative_target=cache[f"{target.positive}"],
@@ -268,6 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# enhance inverted # enhance inverted
EncodedPromptPair( EncodedPromptPair(
target_class=cache[target.target_class], target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"], positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"], positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"], negative_target=cache[f"{target.negative}"],
@@ -299,28 +306,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
multiplier=anchor.multiplier multiplier=anchor.multiplier
) )
] ]
# self.print("Encoding complete. Building prompt pairs..")
# for neutral in self.prompt_txt_list:
# for target in self.slider_config.targets:
# both_prompts_list = [
# f"{target.positive} {target.negative}",
# f"{target.negative} {target.positive}",
# ]
# # randomly pick one of the both prompts to prevent bias
# both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()]
#
# prompt_pair = EncodedPromptPair(
# positive_target=cache[f"{target.positive}"],
# positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
# negative_target=cache[f"{target.negative}"],
# negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
# neutral=cache[neutral],
# both_targets=cache[both_prompts],
# empty_prompt=cache[""],
# target_class=cache[f"{target.target_class}"],
# weight=target.weight,
# ).to(device="cpu", dtype=torch.float32)
# self.prompt_pairs.append(prompt_pair)
# move to cpu to save vram # move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling # We don't need text encoder anymore, but keep it on cpu for sampling
@@ -340,7 +326,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype) dtype = get_torch_dtype(self.train_config.dtype)
# get random multiplier between 1 and 3 # get random multiplier between 1 and 3
rand_weight = torch.rand((1,)).item() * 2 + 1 rand_weight = 1
# rand_weight = torch.rand((1,)).item() * 2 + 1
# get a random pair # get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[ prompt_pair: EncodedPromptPair = self.prompt_pairs[
@@ -367,12 +354,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
lr_scheduler = self.lr_scheduler lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss() loss_function = torch.nn.MSELoss()
def get_noise_pred(p, n, gs, cts, dn): def get_noise_pred(neg, pos, gs, cts, dn):
return self.predict_noise( return self.predict_noise(
latents=dn, latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings( text_embeddings=train_tools.concat_prompt_embeddings(
p, # negative prompt neg, # negative prompt
n, # positive prompt pos, # positive prompt
self.train_config.batch_size, self.train_config.batch_size,
), ),
timestep=cts, timestep=cts,
@@ -410,8 +397,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = self.diffuse_some_steps( denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents latents, # pass simple noise latents
train_tools.concat_prompt_embeddings( train_tools.concat_prompt_embeddings(
positive, # unconditional prompt_pair.positive_target, # unconditional
target_class, # target prompt_pair.target_class, # target
self.train_config.batch_size, self.train_config.batch_size,
), ),
start_timesteps=0, start_timesteps=0,
@@ -426,15 +413,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
] ]
positive_latents = get_noise_pred( positive_latents = get_noise_pred(
positive, negative, 1, current_timestep, denoised_latents prompt_pair.positive_target, # negative prompt
prompt_pair.negative_target, # positive prompt
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32) ).to("cpu", dtype=torch.float32)
neutral_latents = get_noise_pred( neutral_latents = get_noise_pred(
positive, neutral, 1, current_timestep, denoised_latents prompt_pair.positive_target, # negative prompt
prompt_pair.empty_prompt, # positive prompt (normally neutral
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32) ).to("cpu", dtype=torch.float32)
unconditional_latents = get_noise_pred( unconditional_latents = get_noise_pred(
positive, positive, 1, current_timestep, denoised_latents prompt_pair.positive_target, # negative prompt
prompt_pair.positive_target, # positive prompt
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32) ).to("cpu", dtype=torch.float32)
anchor_loss = None anchor_loss = None
@@ -461,7 +460,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network: with self.network:
self.network.multiplier = prompt_pair.multiplier * rand_weight self.network.multiplier = prompt_pair.multiplier * rand_weight
target_latents = get_noise_pred( target_latents = get_noise_pred(
positive, target_class, 1, current_timestep, denoised_latents prompt_pair.positive_target,
prompt_pair.target_class,
1,
current_timestep,
denoised_latents
).to("cpu", dtype=torch.float32) ).to("cpu", dtype=torch.float32)
# if self.logging_config.verbose: # if self.logging_config.verbose: