differential guidance is WORKING (from what I can tell)

This commit is contained in:
Jaret Burkett
2023-11-07 19:24:12 -07:00
parent dc8448d958
commit 1ee62562a4
7 changed files with 101 additions and 61 deletions

View File

@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
from toolkit import train_tools
from toolkit.basic import value_map
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
@@ -32,7 +33,6 @@ class SDTrainer(BaseSDTrainProcess):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.target_class = self.get_conf('target_class', '')
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
@@ -187,84 +187,84 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
):
with torch.no_grad():
conditional_noisy_latents = noisy_latents
dtype = get_torch_dtype(self.train_config.dtype)
# target class is unconditional
target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
if batch.unconditional_latents is not None:
# do the unconditional prediction here instead of a prior prediction
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
timesteps)
# Encode the unconditional image into latents
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
batch.unconditional_latents, noise, timesteps
)
was_network_active = self.network.is_active
# was_network_active = self.network.is_active
self.network.is_active = False
self.sd.unet.eval()
guidance_scale = 1.0
# calculate the differential between our conditional (target image) and our unconditional ("bad" image)
target_differential = unconditional_noisy_latents - conditional_noisy_latents
target_differential = target_differential.detach()
def cfg(uncon, con):
return uncon + guidance_scale * (
con - uncon
)
# add the target differential to the target latents as if it were noise, with the scheduler scaling it to
# the current timestep. Scaling the noise here is IMPORTANT; skipping it leads to a blurry targeted area.
guidance_latents = self.sd.noise_scheduler.add_noise(
conditional_noisy_latents,
target_differential,
timesteps
)
target_conditional = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
# With LoRA network bypassed, predict noise to get a baseline of what the network
# wants to do with the latents + noise. Pass our target latents here for the input.
target_unconditional = self.sd.predict_noise(
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
target_unconditional = self.sd.predict_noise(
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
target_noise = cfg(target_unconditional, target_conditional)
# latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
# target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
# target_noise_res = noisy_latents - unconditional_noisy_latents
# target_pred = cfg(unconditional_noisy_latents, target_pred)
# target_pred = target_pred + target_noise_res
self.network.is_active = True
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = network_weight_list
prediction = self.sd.predict_noise(
latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
# with LoRA active, predict the noise with the scaled differential latents added. This will allow us
# the opportunity to predict the differential + noise that was added to the latents.
prediction_unconditional = self.sd.predict_noise(
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# prediction_res = target_pred - prediction
# remove the baseline conditional prediction. This will leave only the divergence from the baseline and
# the prediction of the added differential noise
prediction_positive = prediction_unconditional - target_unconditional
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
# this is the diffusion training process.
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
# differential between the conditional and unconditional images
positive_loss = torch.nn.functional.mse_loss(
prediction_positive.float(),
target_differential.float(),
reduction="none"
)
positive_loss = positive_loss.mean([1, 2, 3])
# send it backwards BEFORE switching network polarity
positive_loss = self.apply_snr(positive_loss, timesteps)
positive_loss = positive_loss.mean()
positive_loss.backward()
# loss = positive_loss.detach() + negative_loss.detach()
loss = positive_loss.detach()
# prediction = cfg(prediction, target_pred)
# add a grad so other backward does not fail
loss.requires_grad_(True)
loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# restore network
self.network.multiplier = network_weight_list
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
return loss
def get_prior_prediction(