Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added blank prompt preservation
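In short, this commit adds a second output-preservation mode alongside diff_output_preservation: when blank_prompt_preservation is enabled, the trainer caches the text embedding of the empty prompt "", uses it for the prior (base model) prediction, and adds an MSE penalty that keeps the trained network's blank-prompt prediction close to that prior, scaled by blank_prompt_preservation_multiplier. Below is a minimal standalone sketch of that loss term; the callable and tensor names are placeholders for illustration, not the SDTrainer API.

import torch
import torch.nn.functional as F

def blank_prompt_preservation_loss(predict_noise, noisy_latents, timesteps,
                                    blank_embeds, prior_pred, multiplier=1.0):
    # predict_noise: stand-in callable for the trainer's denoiser call
    # blank_embeds:  cached embedding of the empty prompt "", shape (1, seq, dim)
    # prior_pred:    prediction of the unmodified (prior) model, used as a fixed target
    batch_embeds = blank_embeds.expand(noisy_latents.shape[0], *blank_embeds.shape[1:])
    pred = predict_noise(noisy_latents, timesteps, batch_embeds)
    # pull the trained model's blank-prompt output toward the prior model's output
    return F.mse_loss(pred, prior_pred.detach()) * multiplier

# tiny usage example with dummy shapes and a dummy denoiser
dummy_denoiser = lambda latents, t, embeds: latents * 0.9
loss = blank_prompt_preservation_loss(
    dummy_denoiser,
    noisy_latents=torch.randn(2, 4, 8, 8),
    timesteps=torch.randint(0, 1000, (2,)),
    blank_embeds=torch.randn(1, 77, 16),
    prior_pred=torch.randn(2, 4, 8, 8),
)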
@@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess):
                 raise ValueError("diff_output_preservation requires a network to be set")
             if self.train_config.train_text_encoder:
                 raise ValueError("diff_output_preservation is not supported with train_text_encoder")
-
-            # always do a prior prediction when doing diff output preservation
+
+        if self.train_config.blank_prompt_preservation:
+            if self.network_config is None:
+                raise ValueError("blank_prompt_preservation requires a network to be set")
+
+        if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
+            # always do a prior prediction when doing output preservation
             self.do_prior_prediction = True

         # store the loss target for a batch so we can use it in a loss
@@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess):
                 self.sd.text_encoder_to("cpu")
                 flush()

+        if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
+            # make sure we have this if not unloading
+            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
+                self.device_torch,
+                dtype=self.sd.torch_dtype
+            ).detach()
+
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = self.sd.vae
             # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
@@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess):
                 if self.train_config.diff_output_preservation:
                     prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])

+                if self.train_config.blank_prompt_preservation:
+                    blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                        self.device_torch, dtype=dtype
+                    )
+                    prior_embeds_to_use = concat_prompt_embeds(
+                        [blank_embeds] * noisy_latents.shape[0]
+                    )
+
                 prior_pred = self.get_prior_prediction(
                     noisy_latents=noisy_latents,
                     conditional_embeds=prior_embeds_to_use,
@@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess):
                     prior_to_calculate_loss = prior_pred
                     # if we are doing diff_output_preservation and not doing inverted masked prior
                     # then we need to send none here so it will not target the prior
-                    if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+                    doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
+                    if doing_preservation and not do_inverted_masked_prior:
                         prior_to_calculate_loss = None

                 loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess):
                     prior_pred=prior_to_calculate_loss,
                 )

-                if self.train_config.diff_output_preservation:
+                if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
                     # send the loss backwards otherwise checkpointing will fail
                     self.accelerator.backward(loss)
                     normal_loss = loss.detach() # dont send backward again

-                    dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
-                    dop_pred = self.predict_noise(
+                    with torch.no_grad():
+                        if self.train_config.diff_output_preservation:
+                            preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                        elif self.train_config.blank_prompt_preservation:
+                            blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                                self.device_torch, dtype=dtype
+                            )
+                            preservation_embeds = concat_prompt_embeds(
+                                [blank_embeds] * noisy_latents.shape[0]
+                            )
+                    preservation_pred = self.predict_noise(
                         noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                         timesteps=timesteps,
-                        conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                        conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
                         unconditional_embeds=unconditional_embeds,
                         batch=batch,
                         **pred_kwargs
                     )
-                    dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
-                    self.accelerator.backward(dop_loss)
-
-                    loss = normal_loss + dop_loss
+                    multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
+                    preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
+                    self.accelerator.backward(preservation_loss)
+
+                    loss = normal_loss + preservation_loss
                     loss = loss.clone().detach()
                     # require grad again so the backward wont fail
                     loss.requires_grad_(True)
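For context on the tail of this last hunk: the preservation branch calls backward twice (once for the main loss, once for the preservation loss) and then returns a detached sum so the combined value is reported without being backpropagated a second time. A minimal sketch of that pattern outside the trainer follows; the linear model and targets are throwaway placeholders.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 4)       # placeholder for the trained network
x = torch.randn(2, 4)

normal_loss = F.mse_loss(model(x), torch.zeros(2, 4))
normal_loss.backward()              # first backward pass; gradients start accumulating
normal_loss = normal_loss.detach()  # don't send this term backward again

preservation_loss = F.mse_loss(model(x), torch.ones(2, 4))
preservation_loss.backward()        # second backward pass accumulates into the same grads

# detached so backpropagating it again cannot reach the model,
# but requires_grad is set so callers that call backward on it won't error
loss = (normal_loss + preservation_loss).clone().detach()
loss.requires_grad_(True)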