Reworked targeted guidance algo

This commit is contained in:
Jaret Burkett
2023-12-01 06:30:52 -07:00
parent bd2bce9b92
commit 92cb5ae096

View File

@@ -189,15 +189,31 @@ class SDTrainer(BaseSDTrainProcess):
):
with torch.no_grad():
# Perform targeted guidance (working title)
conditional_noisy_latents = noisy_latents.detach() # target images
dtype = get_torch_dtype(self.train_config.dtype)
if batch.unconditional_latents is not None:
# unconditional latents are the "neutral" images. Add noise here identical to
# the noise added to the conditional latents, at the same timesteps
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps).detach()
unconditional_diff = unconditional_latents - conditional_latents
conditional_diff = conditional_latents - unconditional_latents
# we need to determine the amount of signal and noise that would be present at the current timestep
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
target_noise = noise + unconditional_signal
conditional_noisy_latents = self.sd.add_noise(
unconditional_latents + conditional_signal,
target_noise,
timesteps
).detach()
unconditional_noisy_latents = self.sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
self.network.is_active = False
@@ -227,9 +243,12 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs # adapter residuals in here
)
# remove baseline from our prediction to extract our differential prediction
prediction = prediction - baseline_prediction
loss = torch.nn.functional.mse_loss(
prediction.float(),
baseline_prediction.float(),
unconditional_signal.float(),
reduction="none"
)
loss = loss.mean([1, 2, 3])
@@ -303,7 +322,6 @@ class SDTrainer(BaseSDTrainProcess):
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
@@ -321,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
adapter_images = None
sigmas = None
if has_adapter_img and (self.adapter or self.assistant_adapter):