Added blank prompt preservation

Jaret Burkett
2025-10-22 14:55:13 -06:00
parent ca57ffc270
commit ee206cfa18
7 changed files with 143 additions and 28 deletions

View File

@@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess):
raise ValueError("diff_output_preservation requires a network to be set")
if self.train_config.train_text_encoder:
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
# always do a prior prediction when doing diff output preservation
if self.train_config.blank_prompt_preservation:
if self.network_config is None:
raise ValueError("blank_prompt_preservation requires a network to be set")
if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
# always do a prior prediction when doing output preservation
self.do_prior_prediction = True
# store the loss target for a batch so we can use it in a loss
@@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.text_encoder_to("cpu")
flush()
if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
# make sure we have the cached blank embeds even when the text encoder is not being unloaded
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
self.device_torch,
dtype=self.sd.torch_dtype
).detach()
if self.train_config.diffusion_feature_extractor_path is not None:
vae = self.sd.vae
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
@@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation:
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
if self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
prior_embeds_to_use = concat_prompt_embeds(
[blank_embeds] * noisy_latents.shape[0]
)
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=prior_embeds_to_use,
@@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess):
prior_to_calculate_loss = prior_pred
# if we are doing diff_output_preservation or blank_prompt_preservation and not doing an inverted masked prior
# then we need to send None here so it will not target the prior
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
if doing_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess):
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
# send the loss backward now, otherwise checkpointing will fail
self.accelerator.backward(loss)
normal_loss = loss.detach() # don't send this loss backward again
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
dop_pred = self.predict_noise(
with torch.no_grad():
if self.train_config.diff_output_preservation:
preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
elif self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
preservation_embeds = concat_prompt_embeds(
[blank_embeds] * noisy_latents.shape[0]
)
preservation_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
self.accelerator.backward(dop_loss)
loss = normal_loss + dop_loss
multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
self.accelerator.backward(preservation_loss)
loss = normal_loss + preservation_loss
loss = loss.clone().detach()
# require grad again so the backward won't fail
loss.requires_grad_(True)
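
Condensed, the preservation branch in the hunks above works as follows. This is a simplified sketch, not the file's actual structure: the helper names (predict_noise, concat_prompt_embeds, expand_to_batch, cached_blank_embeds, accelerator) are taken from the diff, while the standalone method wrapper and its argument list are assumed for illustration.

import torch

def _preservation_step(self, loss, prior_pred, noisy_latents, timesteps,
                       unconditional_embeds, batch, dtype, **pred_kwargs):
    # back-prop the normal loss first so gradient checkpointing does not fail
    self.accelerator.backward(loss)
    normal_loss = loss.detach()  # do not send this backward again

    # pick the preservation embeds: class prompt for DOP, cached blank prompt for BPP
    if self.train_config.diff_output_preservation:
        embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
        multiplier = self.train_config.diff_output_preservation_multiplier
    else:  # blank_prompt_preservation
        blank = self.cached_blank_embeds.clone().detach().to(self.device_torch, dtype=dtype)
        # concat_prompt_embeds is the toolkit helper used in the hunk above
        embeds = concat_prompt_embeds([blank] * noisy_latents.shape[0])
        multiplier = self.train_config.blank_prompt_preservation_multiplier

    # predict with the LoRA active and pull it toward the LoRA-bypassed prior prediction
    pred = self.predict_noise(
        noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
        timesteps=timesteps,
        conditional_embeds=embeds.to(self.device_torch, dtype=dtype),
        unconditional_embeds=unconditional_embeds,
        batch=batch,
        **pred_kwargs,
    )
    preservation_loss = torch.nn.functional.mse_loss(pred, prior_pred) * multiplier
    self.accelerator.backward(preservation_loss)

    # report the combined value; grad is required again so a later backward call won't fail
    total = (normal_loss + preservation_loss).clone().detach()
    return total.requires_grad_(True)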

View File

@@ -451,7 +451,11 @@ class TrainConfig:
self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
# If the trigger word is in the prompt, we will use this class name to replace it, e.g. "sks woman" -> "woman"
self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
# blank prompt preservation preserves the model's existing output for a blank (empty) prompt
self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False)
self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0)
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
self.match_adapter_chance = 1.0
@@ -1318,5 +1322,8 @@ def validate_configs(
if model_config.arch == 'qwen_image_edit':
if train_config.unload_text_encoder:
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")

View File

@@ -215,12 +215,12 @@ export default function SimpleJob({
</FormGroup>
)}
{modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
<Checkbox
label="Match Target Res"
docKey="model.qie.match_target_res"
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
/>
<Checkbox
label="Match Target Res"
docKey="model.qie.match_target_res"
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
/>
)}
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
<>
@@ -586,16 +586,27 @@ export default function SimpleJob({
</FormGroup>
</div>
<div>
{disableSections.includes('train.diff_output_preservation') ||
disableSections.includes('train.blank_prompt_preservation') ? null : (
<FormGroup label="Regularization">
<></>
</FormGroup>
)}
{disableSections.includes('train.diff_output_preservation') ? null : (
<>
<FormGroup label="Regularization">
<Checkbox
label="Differential Output Preservation"
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
<Checkbox
label="Differential Output Preservation"
docKey={'train.diff_output_preservation'}
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => {
setJobConfig(value, 'config.process[0].train.diff_output_preservation');
if (value && jobConfig.config.process[0].train.blank_prompt_preservation) {
// only one can be enabled at a time
setJobConfig(false, 'config.process[0].train.blank_prompt_preservation');
}
}}
/>
{jobConfig.config.process[0].train.diff_output_preservation && (
<>
<NumberInput
@@ -610,7 +621,7 @@ export default function SimpleJob({
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
className="pt-2 pb-4"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
@@ -621,6 +632,39 @@ export default function SimpleJob({
)}
</>
)}
{disableSections.includes('train.blank_prompt_preservation') ? null : (
<>
<Checkbox
label="Blank Prompt Preservation"
docKey={'train.blank_prompt_preservation'}
className="pt-1"
checked={jobConfig.config.process[0].train.blank_prompt_preservation || false}
onChange={value => {
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation');
if (value && jobConfig.config.process[0].train.diff_output_preservation) {
// only one can be enabled at a time
setJobConfig(false, 'config.process[0].train.diff_output_preservation');
}
}}
/>
{jobConfig.config.process[0].train.blank_prompt_preservation && (
<>
<NumberInput
label="BPP Loss Multiplier"
className="pt-2"
value={
(jobConfig.config.process[0].train.blank_prompt_preservation_multiplier as number) || 1.0
}
onChange={value =>
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>
</>
)}
</>
)}
</div>
</div>
</Card>

View File

@@ -9,6 +9,7 @@ type DisableableSections =
| 'network.conv'
| 'trigger_word'
| 'train.diff_output_preservation'
| 'train.blank_prompt_preservation'
| 'train.unload_text_encoder'
| 'slider';

View File

@@ -228,6 +228,36 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'train.diff_output_preservation': {
title: 'Differential Output Preservation',
description: (
<>
Differential Output Preservation (DOP) is a technique that helps preserve the class of the trained concept during
training. For this, you must have a trigger word set to differentiate your concept from its class. For instance,
you may be training a woman named Alice: your trigger word may be "Alice", and the class is "woman", since Alice
is a woman. We want to teach the model to remember what it knows about the class "woman" while teaching it what
is different about Alice. During training, the trainer will make a prediction with your LoRA bypassed and with
the trigger word in the prompt replaced by the class word, turning "photo of Alice" into "photo of woman". This
prediction is called the prior prediction. At each step, the trainer performs the normal training step plus an
additional step that uses this prior prediction and the class prompt to teach the LoRA to preserve the knowledge
of the class. This should not only improve the performance of your trained concept, but also allow prompts like
"Alice standing next to a woman" without making both people look like Alice.
</>
),
},
'train.blank_prompt_preservation': {
title: 'Blank Prompt Preservation',
description: (
<>
Blank Prompt Preservation (BPP) is a technique that helps preserve the model's existing knowledge when it is
given no prompt. This not only helps the model stay more flexible, but also improves the quality of your concept
at inference, especially when the model uses CFG (Classifier-Free Guidance). At each training step, a prior
prediction is made with a blank prompt and with the LoRA disabled. This prediction is then used as the target of
an additional training step, also with a blank prompt, to preserve the model's behavior when no prompt is given.
This helps the LoRA avoid overfitting to the training prompts and retain the base model's generalization.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
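
Both preservation modes described above follow the same pattern at train time: compare a LoRA-active prediction against a LoRA-bypassed prior prediction on the preservation prompt. A conceptual sketch, where the unet call signature and the lora.disable() context manager are hypothetical stand-ins rather than the trainer's real API:

import torch

def preservation_loss(unet, lora, noisy_latents, timesteps, embeds, multiplier=1.0):
    # prior prediction: LoRA bypassed; embeds come from the class prompt (DOP) or a blank prompt (BPP)
    with torch.no_grad(), lora.disable():
        prior_pred = unet(noisy_latents, timesteps, embeds)
    # same prediction with the LoRA active; the MSE pulls it back toward the prior
    pred = unet(noisy_latents, timesteps, embeds)
    return torch.nn.functional.mse_loss(pred, prior_pred) * multiplier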

View File

@@ -135,6 +135,8 @@ export interface TrainConfig {
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
blank_prompt_preservation?: boolean;
blank_prompt_preservation_multiplier?: number;
switch_boundary_every: number;
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
}

View File

@@ -1 +1 @@
VERSION = "0.7.1"
VERSION = "0.7.2"