Added blank prompt preservation

Jaret Burkett
2025-10-22 14:55:13 -06:00
parent ca57ffc270
commit ee206cfa18
7 changed files with 143 additions and 28 deletions


@@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess):
raise ValueError("diff_output_preservation requires a network to be set")
if self.train_config.train_text_encoder:
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
# always do a prior prediction when doing diff output preservation
if self.train_config.blank_prompt_preservation:
if self.network_config is None:
raise ValueError("blank_prompt_preservation requires a network to be set")
if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
# always do a prior prediction when doing output preservation
self.do_prior_prediction = True
# store the loss target for a batch so we can use it in a loss
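For readers following along, here is a minimal, self-contained sketch of the gating this hunk introduces; _Cfg and wants_prior_prediction are hypothetical stand-ins for train_config and the do_prior_prediction assignment, not names from the codebase:

from dataclasses import dataclass

@dataclass
class _Cfg:  # hypothetical stand-in for the relevant train_config flags
    blank_prompt_preservation: bool = False
    diff_output_preservation: bool = False

def wants_prior_prediction(cfg: _Cfg, has_network: bool) -> bool:
    # blank_prompt_preservation, like diff_output_preservation, only applies when
    # a separate network (e.g. a LoRA) is trained on top of the frozen base model
    if cfg.blank_prompt_preservation and not has_network:
        raise ValueError("blank_prompt_preservation requires a network to be set")
    # either preservation mode needs a prior prediction each step to use as a target
    return cfg.blank_prompt_preservation or cfg.diff_output_preservation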
@@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.text_encoder_to("cpu")
flush()
if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
# make sure we have this if not unloading
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
self.device_torch,
dtype=self.sd.torch_dtype
).detach()
if self.train_config.diffusion_feature_extractor_path is not None:
vae = self.sd.vae
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
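A rough sketch of the lazy "encode the blank prompt once, reuse it afterwards" pattern this hunk adds; BlankEmbedCache is an illustrative wrapper, encode_prompt stands in for self.sd.encode_prompt, and the [1, 77, 768] shape is only a toy example:

import torch

class BlankEmbedCache:
    """Encode the blank prompt once and hand out the detached result afterwards."""

    def __init__(self, encode_prompt, device, dtype):
        self.encode_prompt = encode_prompt  # callable: str -> Tensor
        self.device = device
        self.dtype = dtype
        self.cached = None

    def get(self) -> torch.Tensor:
        if self.cached is None:
            # detach so no gradients ever flow back into the text encoder
            self.cached = self.encode_prompt("").to(self.device, dtype=self.dtype).detach()
        return self.cached

# toy usage with a fake encoder
cache = BlankEmbedCache(lambda p: torch.zeros(1, 77, 768), "cpu", torch.float32)
blank = cache.get()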
@@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation:
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
if self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
prior_embeds_to_use = concat_prompt_embeds(
[blank_embeds] * noisy_latents.shape[0]
)
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=prior_embeds_to_use,
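The cached blank embedding is expanded to the batch size before being handed to get_prior_prediction. For a plain embedding tensor, torch.cat along the batch dimension approximates what concat_prompt_embeds does (assumption: the real helper also carries pooled and other model-specific fields):

import torch

blank_embeds = torch.randn(1, 77, 768)   # illustrative [1, seq_len, dim] cache entry
batch_size = 4                           # stands in for noisy_latents.shape[0]
prior_embeds_to_use = torch.cat([blank_embeds] * batch_size, dim=0)
assert prior_embeds_to_use.shape == (batch_size, 77, 768)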
@@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess):
prior_to_calculate_loss = prior_pred
# if we are doing diff_output_preservation and not doing inverted masked prior
# then we need to send none here so it will not target the prior
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
if doing_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess):
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
# send the loss backwards otherwise checkpointing will fail
self.accelerator.backward(loss)
normal_loss = loss.detach() # don't send backward again
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
dop_pred = self.predict_noise(
with torch.no_grad():
if self.train_config.diff_output_preservation:
preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
elif self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
preservation_embeds = concat_prompt_embeds(
[blank_embeds] * noisy_latents.shape[0]
)
preservation_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
self.accelerator.backward(dop_loss)
loss = normal_loss + dop_loss
multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
self.accelerator.backward(preservation_loss)
loss = normal_loss + preservation_loss
loss = loss.clone().detach()
# require grad again so the backward won't fail
loss.requires_grad_(True)
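Taken together with the previous hunk, the step now does two backward passes: the normal loss first (with the prior prediction withheld from it), then an MSE term that pulls the network's blank-prompt prediction back toward the frozen prior prediction, scaled by blank_prompt_preservation_multiplier (or diff_output_preservation_multiplier). A self-contained sketch of that arithmetic, with toy tensors and a generic backward callable standing in for accelerator.backward:

import torch
import torch.nn.functional as F

def preservation_step(normal_loss, preservation_pred, prior_pred, multiplier, backward):
    # first backward pass: the ordinary training loss (the prior was already withheld
    # from it when a preservation mode is active and no inverted masked prior is used)
    backward(normal_loss)
    normal_loss = normal_loss.detach()  # don't send it backward again

    # second backward pass: keep the blank-prompt prediction close to the frozen prior
    preservation_loss = F.mse_loss(preservation_pred, prior_pred) * multiplier
    backward(preservation_loss)

    # the summed value is only reported; gradients were accumulated by the two calls above
    return (normal_loss + preservation_loss).detach().requires_grad_(True)

# toy usage with made-up tensors in place of the real model outputs
pred = torch.randn(2, 4, 8, 8, requires_grad=True)
prior = torch.randn(2, 4, 8, 8)
normal = F.mse_loss(pred, torch.zeros_like(pred))
reported = preservation_step(normal, pred, prior, multiplier=1.0,
                             backward=lambda l: l.backward(retain_graph=True))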