mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added base for using guidance during training. Still not working right.
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
@@ -30,6 +32,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.do_prior_prediction = False
|
||||
self.target_class = self.get_conf('target_class', '')
|
||||
if self.train_config.inverted_mask_prior:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
@@ -171,6 +174,99 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
return batch
|
||||
|
||||
def get_guided_loss(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# target class is unconditional
|
||||
target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
|
||||
|
||||
if batch.unconditional_latents is not None:
|
||||
# do the unconditional prediction here instead of a prior prediction
|
||||
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
|
||||
timesteps)
|
||||
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
|
||||
guidance_scale = 1.0
|
||||
|
||||
def cfg(uncon, con):
|
||||
return uncon + guidance_scale * (
|
||||
con - uncon
|
||||
)
|
||||
|
||||
target_conditional = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
target_unconditional = self.sd.predict_noise(
|
||||
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
|
||||
|
||||
target_noise = cfg(target_unconditional, target_conditional)
|
||||
# latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
|
||||
|
||||
# target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
|
||||
|
||||
# target_noise_res = noisy_latents - unconditional_noisy_latents
|
||||
|
||||
# target_pred = cfg(unconditional_noisy_latents, target_pred)
|
||||
# target_pred = target_pred + target_noise_res
|
||||
|
||||
self.network.is_active = True
|
||||
self.sd.unet.train()
|
||||
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
# prediction_res = target_pred - prediction
|
||||
|
||||
|
||||
# prediction = cfg(prediction, target_pred)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
def get_prior_prediction(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
@@ -369,8 +465,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
else:
|
||||
prompt_2_list = [prompts_2]
|
||||
|
||||
|
||||
|
||||
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
|
||||
noisy_latents_list,
|
||||
noise_list,
|
||||
@@ -386,8 +480,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||
long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
@@ -398,8 +493,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||
long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
@@ -450,27 +546,42 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
|
||||
self.before_unet_predict()
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
# do a prior pred if we have an unconditional image, we will swap out the giadance later
|
||||
if batch.unconditional_latents is not None:
|
||||
# do guided loss
|
||||
loss = self.get_guided_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
else:
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
Reference in New Issue
Block a user