mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added random weight adjuster to prevent overfitting
This commit is contained in:
@@ -3,6 +3,8 @@ import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if self.sd.is_xl:
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
@@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
guided_target = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
# noise_pred = rescale_noise_cfg(
|
||||
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
||||
# )
|
||||
|
||||
noise_pred = guided_target
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
else:
|
||||
noise_pred = train_util.predict_noise(
|
||||
self.sd.unet,
|
||||
self.sd.noise_scheduler,
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.sd.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
latents,
|
||||
text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings,
|
||||
guidance_scale=guidance_scale
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
).sample
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
return noise_pred
|
||||
|
||||
Reference in New Issue
Block a user