From ee206cfa18b52f91b8b4cba9395c687f050d2c4e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 22 Oct 2025 14:55:13 -0600 Subject: [PATCH] Added blank prompt preservation --- extensions_built_in/sd_trainer/SDTrainer.py | 53 ++++++++++++--- toolkit/config_modules.py | 9 ++- ui/src/app/jobs/new/SimpleJob.tsx | 74 ++++++++++++++++----- ui/src/app/jobs/new/options.ts | 1 + ui/src/docs.tsx | 30 +++++++++ ui/src/types.ts | 2 + version.py | 2 +- 7 files changed, 143 insertions(+), 28 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 03944e6c..178534aa 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess): raise ValueError("diff_output_preservation requires a network to be set") if self.train_config.train_text_encoder: raise ValueError("diff_output_preservation is not supported with train_text_encoder") - - # always do a prior prediction when doing diff output preservation + + if self.train_config.blank_prompt_preservation: + if self.network_config is None: + raise ValueError("blank_prompt_preservation requires a network to be set") + + if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation: + # always do a prior prediction when doing output preservation self.do_prior_prediction = True # store the loss target for a batch so we can use it in a loss @@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess): self.sd.text_encoder_to("cpu") flush() + if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: + # make sure we have this if not unloading + self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() + if self.train_config.diffusion_feature_extractor_path is not None: vae = self.sd.vae # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": @@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diff_output_preservation: prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + if self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + prior_embeds_to_use = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, conditional_embeds=prior_embeds_to_use, @@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess): prior_to_calculate_loss = prior_pred # if we are doing diff_output_preservation and not noing inverted masked prior # then we need to send none here so it will not target the prior - if self.train_config.diff_output_preservation and not do_inverted_masked_prior: + doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation + if doing_preservation and not do_inverted_masked_prior: prior_to_calculate_loss = None loss = self.calculate_loss( @@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess): prior_pred=prior_to_calculate_loss, ) - if self.train_config.diff_output_preservation: + if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: # send the loss backwards otherwise checkpointing will fail self.accelerator.backward(loss) normal_loss = loss.detach() # dont send backward again - dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) - dop_pred = self.predict_noise( + with torch.no_grad(): + if self.train_config.diff_output_preservation: + preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + elif self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + preservation_embeds = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + preservation_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), timesteps=timesteps, - conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype), + conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds, batch=batch, **pred_kwargs ) - dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier - self.accelerator.backward(dop_loss) - - loss = normal_loss + dop_loss + multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier + preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier + self.accelerator.backward(preservation_loss) + + loss = normal_loss + preservation_loss loss = loss.clone().detach() # require grad again so the backward wont fail loss.requires_grad_(True) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 44f47a71..6aeb9466 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -451,7 +451,11 @@ class TrainConfig: self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0) # If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman" self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '') - + + # blank prompt preservation will preserve the model's knowledge of a blank prompt + self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False) + self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0) + # legacy if match_adapter_assist and self.match_adapter_chance == 0.0: self.match_adapter_chance = 1.0 @@ -1318,5 +1322,8 @@ def validate_configs( if model_config.arch == 'qwen_image_edit': if train_config.unload_text_encoder: raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though") + + if train_config.diff_output_preservation and train_config.blank_prompt_preservation: + raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.") diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index c13aaf4e..fa9d532a 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -215,12 +215,12 @@ export default function SimpleJob({ )} {modelArch?.additionalSections?.includes('model.qie.match_target_res') && ( - setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')} - /> + setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')} + /> )} {modelArch?.additionalSections?.includes('model.layer_offloading') && ( <> @@ -586,16 +586,27 @@ export default function SimpleJob({
+ {disableSections.includes('train.diff_output_preservation') || + disableSections.includes('train.blank_prompt_preservation') ? null : ( + + <> + + )} {disableSections.includes('train.diff_output_preservation') ? null : ( <> - - setJobConfig(value, 'config.process[0].train.diff_output_preservation')} - /> - + { + setJobConfig(value, 'config.process[0].train.diff_output_preservation'); + if (value && jobConfig.config.process[0].train.blank_prompt_preservation) { + // only one can be enabled at a time + setJobConfig(false, 'config.process[0].train.blank_prompt_preservation'); + } + }} + /> {jobConfig.config.process[0].train.diff_output_preservation && ( <> setJobConfig(value, 'config.process[0].train.diff_output_preservation_class') @@ -621,6 +632,39 @@ export default function SimpleJob({ )} )} + {disableSections.includes('train.blank_prompt_preservation') ? null : ( + <> + { + setJobConfig(value, 'config.process[0].train.blank_prompt_preservation'); + if (value && jobConfig.config.process[0].train.diff_output_preservation) { + // only one can be enabled at a time + setJobConfig(false, 'config.process[0].train.diff_output_preservation'); + } + }} + /> + {jobConfig.config.process[0].train.blank_prompt_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + + )} + + )}
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 07596a0a..0ef68e57 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -9,6 +9,7 @@ type DisableableSections = | 'network.conv' | 'trigger_word' | 'train.diff_output_preservation' + | 'train.blank_prompt_preservation' | 'train.unload_text_encoder' | 'slider'; diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index d3c7dafd..c7545793 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -228,6 +228,36 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'train.diff_output_preservation': { + title: 'Differential Output Preservation', + description: ( + <> + Differential Output Preservation (DOP) is a technique to help preserve class of the trained concept during + training. For this, you must have a trigger word set to differentiate your concept from its class. For instance, + You may be training a woman named Alice. Your trigger word may be "Alice". The class is "woman", since Alice is + a woman. We want to teach the model to remember what it knows about the class "woman" while teaching it what is + different about Alice. During training, the trainer will make a prediction with your LoRA bypassed and your + trigger word in the prompt replaced with the class word. Making "photo of Alice" become "photo of woman". This + prediction is called the prior prediction. Each step, we will do the normal training step, but also do another + step with this prior prediction and the class prompt in order to teach our LoRA to preserve the knowledge of the + class. This should not only improve the performance of your trained concept, but also allow you to do things + like "Alice standing next to a woman" and not make both of the people look like Alice. + + ), + }, + 'train.blank_prompt_preservation': { + title: 'Blank Prompt Preservation', + description: ( + <> + Blank Prompt Preservation (BPP) is a technique to help preserve the current models knowledge when unprompted. + This will not only help the model become more flexible, but will also help the quality of your concept during + inference, especially when a model uses CFG (Classifier Free Guidance) on inference. At each step during + training, a prior prediction is made with a blank prompt and with the LoRA disabled. This prediction is then + used as a target on an additional training step with a blank prompt, to preserve the model's knowledge when no + prompt is given. This helps the model to not overfit to the prompt and retain its generalization capabilities. + + ), + }, }; export const getDoc = (key: string | null | undefined): ConfigDoc | null => { diff --git a/ui/src/types.ts b/ui/src/types.ts index a0b871e6..d01bd89f 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -135,6 +135,8 @@ export interface TrainConfig { diff_output_preservation: boolean; diff_output_preservation_multiplier: number; diff_output_preservation_class: string; + blank_prompt_preservation?: boolean; + blank_prompt_preservation_multiplier?: number; switch_boundary_every: number; loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped'; } diff --git a/version.py b/version.py index bfbfd0f9..2e6e6b89 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.1" \ No newline at end of file +VERSION = "0.7.2" \ No newline at end of file