Added random weight adjuster to prevent overfitting

This commit is contained in:
Jaret Burkett
2023-07-29 19:30:14 -06:00
parent c35b78f0d4
commit c01673f1b5
2 changed files with 56 additions and 44 deletions

View File

@@ -3,6 +3,8 @@ import time
from collections import OrderedDict
import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
add_time_ids=None,
**kwargs,
):
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided_target = noise_pred_uncond + guidance_scale * (
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
# noise_pred = rescale_noise_cfg(
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
# )
noise_pred = guided_target
if guidance_rescale > 0.0:
# Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
else:
noise_pred = train_util.predict_noise(
self.sd.unet,
self.sd.noise_scheduler,
# for classifier-free guidance, duplicate the latents so the unconditional and text-conditioned predictions come from one batched pass
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep,
latents,
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings,
guidance_scale=guidance_scale
encoder_hidden_states=text_embeddings.text_embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred