mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Reworked targeted guidance algo
@@ -189,15 +189,31 @@ class SDTrainer(BaseSDTrainProcess):
        ):
            with torch.no_grad():
                # Perform targeted guidance (working title)
                conditional_noisy_latents = noisy_latents.detach()  # target images
                dtype = get_torch_dtype(self.train_config.dtype)

                if batch.unconditional_latents is not None:
                    # unconditional latents are the "neutral" images. Add noise here identical to
                    # the noise added to the conditional latents, at the same timesteps
                    conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
                    unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()

                    unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps).detach()
                    unconditional_diff = unconditional_latents - conditional_latents
                    conditional_diff = conditional_latents - unconditional_latents

                    # we need to determine the amount of signal and noise that would be present at the current timestep
                    conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
                    unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
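                    # (Assuming a DDPM-style scheduler, add_noise(x, n, t) returns
                    # sqrt(abar_t) * x + sqrt(1 - abar_t) * n, so zeroing one argument
                    # isolates the scheduler's scaling: conditional_signal is
                    # conditional_diff at the signal level of timestep t, and
                    # unconditional_signal is unconditional_diff at its noise level.)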

                    target_noise = noise + unconditional_signal

                    conditional_noisy_latents = self.sd.add_noise(
                        unconditional_latents + conditional_signal,
                        target_noise,
                        timesteps
                    ).detach()

                    unconditional_noisy_latents = self.sd.add_noise(
                        unconditional_latents,
                        noise,
                        timesteps
                    ).detach()

                    # Disable the LoRA network so we can predict parent network knowledge without it
                    self.network.is_active = False
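
The construction above has a checkable closed form. A minimal standalone sketch, assuming a DDPM-style scheduler where add_noise(x, n, t) computes sqrt(abar_t) * x + sqrt(1 - abar_t) * n; every name below is a local stand-in, not the trainer's actual API:

    import torch

    abar = torch.tensor(0.7)                 # alpha_bar at some timestep t (assumed value)
    a, b = abar.sqrt(), (1 - abar).sqrt()

    def add_noise(x, n):                     # stand-in for self.sd.add_noise(x, n, timesteps)
        return a * x + b * n

    cond = torch.randn(4)                    # conditional (target) latents
    uncond = torch.randn(4)                  # unconditional (neutral) latents
    noise = torch.randn(4)

    conditional_signal = add_noise(cond - uncond, torch.zeros(4))    # a * (cond - uncond)
    unconditional_signal = add_noise(torch.zeros(4), uncond - cond)  # b * (uncond - cond)
    target_noise = noise + unconditional_signal

    cond_noisy = add_noise(uncond + conditional_signal, target_noise)
    uncond_noisy = add_noise(uncond, noise)

    # The shared noise cancels: the two noisy latents differ only along the
    # class direction (cond - uncond), rescaled by (a**2 - b**2).
    assert torch.allclose(cond_noisy - uncond_noisy,
                          (a ** 2 - b ** 2) * (cond - uncond), atol=1e-5)

Under that assumption, both branches see identical noise at every timestep and differ only along the concept direction the LoRA is being trained to capture.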
@@ -227,9 +243,12 @@ class SDTrainer(BaseSDTrainProcess):
                    **pred_kwargs  # adapter residuals in here
                )

                # remove baseline from our prediction to extract our differential prediction
                prediction = prediction - baseline_prediction

                loss = torch.nn.functional.mse_loss(
                    prediction.float(),
                    unconditional_signal.float(),
                    reduction="none"
                )
                loss = loss.mean([1, 2, 3])
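
A self-contained sketch of this differential objective, assuming prediction comes from the network with the LoRA active and baseline_prediction from the same inputs with it disabled (shapes and values are illustrative):

    import torch

    prediction = torch.randn(2, 4, 8, 8)            # LoRA-on prediction (stand-in)
    baseline_prediction = torch.randn(2, 4, 8, 8)   # LoRA-off prediction (stand-in)
    unconditional_signal = torch.randn(2, 4, 8, 8)  # timestep-scaled class direction

    # Subtracting the baseline leaves only the LoRA's contribution, which is
    # regressed onto the unconditional signal.
    differential = prediction - baseline_prediction
    loss = torch.nn.functional.mse_loss(
        differential.float(),
        unconditional_signal.float(),
        reduction="none",
    )
    loss = loss.mean([1, 2, 3])                     # one scalar per batch item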
@@ -303,7 +322,6 @@ class SDTrainer(BaseSDTrainProcess):

        match_adapter_assist = False

        # check if we are matching the adapter assistant
        if self.assistant_adapter:
            if self.train_config.match_adapter_chance == 1.0:
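
The body of this branch is truncated in the hunk. A purely illustrative guess at the logic the config name suggests; the elif branch and the random draw are assumptions, not the commit's code:

    import random

    match_adapter_assist = False
    if assistant_adapter is not None:      # stand-in for self.assistant_adapter
        if match_adapter_chance == 1.0:    # stand-in for self.train_config.match_adapter_chance
            match_adapter_assist = True
        elif match_adapter_chance > 0.0:
            match_adapter_assist = random.random() < match_adapter_chance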
@@ -321,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
            if file_item.is_reg:
                loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight

        adapter_images = None
        sigmas = None
        if has_adapter_img and (self.adapter or self.assistant_adapter):
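
The regularization weighting in this hunk can be exercised in isolation. A minimal sketch; the is_reg flags and the reg_weight value are illustrative assumptions:

    import torch

    loss = torch.tensor([0.9, 1.1, 0.8])  # per-item losses, e.g. from loss.mean([1, 2, 3])
    is_reg = [False, True, False]         # which batch items are regularization images
    reg_weight = 0.5                      # stand-in for self.train_config.reg_weight

    loss_multiplier = torch.ones_like(loss)
    for idx, flag in enumerate(is_reg):
        if flag:
            loss_multiplier[idx] = loss_multiplier[idx] * reg_weight

    weighted_loss = (loss * loss_multiplier).mean()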