Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added blank prompt preservation
@@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess):
                 raise ValueError("diff_output_preservation requires a network to be set")
             if self.train_config.train_text_encoder:
                 raise ValueError("diff_output_preservation is not supported with train_text_encoder")

-            # always do a prior prediction when doing diff output preservation
+        if self.train_config.blank_prompt_preservation:
+            if self.network_config is None:
+                raise ValueError("blank_prompt_preservation requires a network to be set")
+
+        if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
+            # always do a prior prediction when doing output preservation
             self.do_prior_prediction = True

         # store the loss target for a batch so we can use it in a loss
@@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess):
             self.sd.text_encoder_to("cpu")
             flush()

+        if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
+            # make sure we have this if not unloading
+            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
+                self.device_torch,
+                dtype=self.sd.torch_dtype
+            ).detach()
+
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = self.sd.vae
             # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
@@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess):
             if self.train_config.diff_output_preservation:
                 prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+
+            if self.train_config.blank_prompt_preservation:
+                blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                    self.device_torch, dtype=dtype
+                )
+                prior_embeds_to_use = concat_prompt_embeds(
+                    [blank_embeds] * noisy_latents.shape[0]
+                )

             prior_pred = self.get_prior_prediction(
                 noisy_latents=noisy_latents,
                 conditional_embeds=prior_embeds_to_use,
@@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess):
             prior_to_calculate_loss = prior_pred
             # if we are doing diff_output_preservation and not doing inverted masked prior
             # then we need to send none here so it will not target the prior
-            if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+            doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
+            if doing_preservation and not do_inverted_masked_prior:
                 prior_to_calculate_loss = None

         loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess):
             prior_pred=prior_to_calculate_loss,
         )

-        if self.train_config.diff_output_preservation:
+        if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
             # send the loss backwards otherwise checkpointing will fail
             self.accelerator.backward(loss)
             normal_loss = loss.detach()  # don't send backward again

-            dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
-            dop_pred = self.predict_noise(
+            with torch.no_grad():
+                if self.train_config.diff_output_preservation:
+                    preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                elif self.train_config.blank_prompt_preservation:
+                    blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                        self.device_torch, dtype=dtype
+                    )
+                    preservation_embeds = concat_prompt_embeds(
+                        [blank_embeds] * noisy_latents.shape[0]
+                    )
+            preservation_pred = self.predict_noise(
                 noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                 timesteps=timesteps,
-                conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
                 **pred_kwargs
             )
-            dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
-            self.accelerator.backward(dop_loss)
-
-            loss = normal_loss + dop_loss
+            multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
+            preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
+            self.accelerator.backward(preservation_loss)
+
+            loss = normal_loss + preservation_loss
             loss = loss.clone().detach()
             # require grad again so the backward won't fail
             loss.requires_grad_(True)
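
The comments in this hunk explain the ordering: the normal loss is sent backward before the extra preservation pass, and the value that is ultimately returned is rebuilt afterwards, with requires_grad restored because the comment suggests a later backward call still expects a grad-enabled loss. A minimal sketch of that two-backward pattern, using a toy module and stand-in tensors rather than the trainer's real objects:

    import torch

    net = torch.nn.Linear(8, 8)            # stand-in for the trainable network (e.g. a LoRA)
    opt = torch.optim.SGD(net.parameters(), lr=1e-3)

    x = torch.randn(4, 8)                  # stand-in for noisy latents
    target = torch.randn(4, 8)             # stand-in for the normal training target
    prior_pred = torch.randn(4, 8)         # stand-in for the detached prior prediction
    multiplier = 1.0

    opt.zero_grad()
    pred = net(x)
    loss = torch.nn.functional.mse_loss(pred, target)
    loss.backward()                        # first backward: the normal loss
    normal_loss = loss.detach()            # keep the value, don't send it backward again

    preservation_pred = net(x)             # second forward, with the preservation prompt
    preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
    preservation_loss.backward()           # second backward: the preservation loss

    loss = (normal_loss + preservation_loss).clone().detach()
    loss.requires_grad_(True)              # reported loss; real gradients were already accumulated
    opt.step()

Gradients from both passes accumulate on the same parameters before the optimizer step, so the detached sum only exists for reporting and for any downstream backward call.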
@@ -451,7 +451,11 @@ class TrainConfig:
         self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
         # If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman"
         self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')

+        # blank prompt preservation will preserve the model's knowledge of a blank prompt
+        self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False)
+        self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0)
+
         # legacy
         if match_adapter_assist and self.match_adapter_chance == 0.0:
             self.match_adapter_chance = 1.0
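
Since TrainConfig reads these flags straight from keyword arguments, enabling the feature only needs the two new keys. A hedged example of the relevant kwargs (field names are taken from the hunk above; every other required field is omitted):

    train_kwargs = {
        "blank_prompt_preservation": True,            # enable the extra blank-prompt step
        "blank_prompt_preservation_multiplier": 1.0,  # weight of the preservation loss
        # mutually exclusive with diff_output_preservation (see the validate_configs hunk below)
    }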
@@ -1318,5 +1322,8 @@ def validate_configs(
     if model_config.arch == 'qwen_image_edit':
         if train_config.unload_text_encoder:
             raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")

+    if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
+        raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")
+
@@ -215,12 +215,12 @@ export default function SimpleJob({
             </FormGroup>
           )}
           {modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
-            <Checkbox
-              label="Match Target Res"
-              docKey="model.qie.match_target_res"
-              checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
-              onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
-            />
+            <Checkbox
+              label="Match Target Res"
+              docKey="model.qie.match_target_res"
+              checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
+              onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
+            />
           )}
           {modelArch?.additionalSections?.includes('model.layer_offloading') && (
             <>
@@ -586,16 +586,27 @@ export default function SimpleJob({
              </FormGroup>
            </div>
            <div>
+             {disableSections.includes('train.diff_output_preservation') ||
+             disableSections.includes('train.blank_prompt_preservation') ? null : (
+               <FormGroup label="Regularization">
+                 <></>
+               </FormGroup>
+             )}
              {disableSections.includes('train.diff_output_preservation') ? null : (
                <>
-                 <FormGroup label="Regularization">
-                   <Checkbox
-                     label="Differential Output Preservation"
-                     className="pt-1"
-                     checked={jobConfig.config.process[0].train.diff_output_preservation || false}
-                     onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
-                   />
-                 </FormGroup>
+                 <Checkbox
+                   label="Differential Output Preservation"
+                   docKey={'train.diff_output_preservation'}
+                   className="pt-1"
+                   checked={jobConfig.config.process[0].train.diff_output_preservation || false}
+                   onChange={value => {
+                     setJobConfig(value, 'config.process[0].train.diff_output_preservation');
+                     if (value && jobConfig.config.process[0].train.blank_prompt_preservation) {
+                       // only one can be enabled at a time
+                       setJobConfig(false, 'config.process[0].train.blank_prompt_preservation');
+                     }
+                   }}
+                 />
                  {jobConfig.config.process[0].train.diff_output_preservation && (
                    <>
                      <NumberInput
@@ -610,7 +621,7 @@ export default function SimpleJob({
                      />
                      <TextInput
                        label="DOP Preservation Class"
-                       className="pt-2"
+                       className="pt-2 pb-4"
                        value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                        onChange={value =>
                          setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
@@ -621,6 +632,39 @@ export default function SimpleJob({
                  )}
                </>
              )}
+             {disableSections.includes('train.blank_prompt_preservation') ? null : (
+               <>
+                 <Checkbox
+                   label="Blank Prompt Preservation"
+                   docKey={'train.blank_prompt_preservation'}
+                   className="pt-1"
+                   checked={jobConfig.config.process[0].train.blank_prompt_preservation || false}
+                   onChange={value => {
+                     setJobConfig(value, 'config.process[0].train.blank_prompt_preservation');
+                     if (value && jobConfig.config.process[0].train.diff_output_preservation) {
+                       // only one can be enabled at a time
+                       setJobConfig(false, 'config.process[0].train.diff_output_preservation');
+                     }
+                   }}
+                 />
+                 {jobConfig.config.process[0].train.blank_prompt_preservation && (
+                   <>
+                     <NumberInput
+                       label="BPP Loss Multiplier"
+                       className="pt-2"
+                       value={
+                         (jobConfig.config.process[0].train.blank_prompt_preservation_multiplier as number) || 1.0
+                       }
+                       onChange={value =>
+                         setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier')
+                       }
+                       placeholder="eg. 1.0"
+                       min={0}
+                     />
+                   </>
+                 )}
+               </>
+             )}
            </div>
          </div>
        </Card>
@@ -9,6 +9,7 @@ type DisableableSections =
   | 'network.conv'
   | 'trigger_word'
   | 'train.diff_output_preservation'
+  | 'train.blank_prompt_preservation'
   | 'train.unload_text_encoder'
   | 'slider';
@@ -228,6 +228,36 @@ const docs: { [key: string]: ConfigDoc } = {
      </>
    ),
  },
+  'train.diff_output_preservation': {
+    title: 'Differential Output Preservation',
+    description: (
+      <>
+        Differential Output Preservation (DOP) is a technique to help preserve the class of the trained concept
+        during training. For this, you must have a trigger word set to differentiate your concept from its class.
+        For instance, you may be training a woman named Alice. Your trigger word may be "Alice". The class is
+        "woman", since Alice is a woman. We want to teach the model to remember what it knows about the class
+        "woman" while teaching it what is different about Alice. During training, the trainer will make a
+        prediction with your LoRA bypassed and your trigger word in the prompt replaced with the class word,
+        making "photo of Alice" become "photo of woman". This prediction is called the prior prediction. Each
+        step, we will do the normal training step, but also do another step with this prior prediction and the
+        class prompt in order to teach our LoRA to preserve the knowledge of the class. This should not only
+        improve the performance of your trained concept, but also allow you to do things like "Alice standing
+        next to a woman" and not make both of the people look like Alice.
+      </>
+    ),
+  },
+  'train.blank_prompt_preservation': {
+    title: 'Blank Prompt Preservation',
+    description: (
+      <>
+        Blank Prompt Preservation (BPP) is a technique to help preserve the current model's knowledge when it is
+        unprompted. This will not only help the model become more flexible, but will also help the quality of
+        your concept during inference, especially when a model uses CFG (Classifier Free Guidance) at inference.
+        At each step during training, a prior prediction is made with a blank prompt and with the LoRA disabled.
+        This prediction is then used as a target on an additional training step with a blank prompt, to preserve
+        the model's knowledge when no prompt is given. This helps the model to not overfit to the prompt and to
+        retain its generalization capabilities.
+      </>
+    ),
+  },
 };

 export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
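
The DOP description above reduces to a prompt substitution plus an extra regression against the bypassed model's output. A small illustrative sketch of the substitution step only, with hypothetical names (make_prior_prompt is not a real trainer function):

    def make_prior_prompt(prompt: str, trigger_word: str, preservation_class: str) -> str:
        # "photo of Alice" -> "photo of woman" when the trigger is "Alice" and the class is "woman"
        return prompt.replace(trigger_word, preservation_class)

    assert make_prior_prompt("photo of Alice", "Alice", "woman") == "photo of woman"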
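
Read together with the trainer hunks above, the BPP step amounts to: predict once with the LoRA bypassed on an empty prompt, cache that as the target, then pull the LoRA-enabled prediction for the same empty prompt toward it. A condensed sketch with stand-in callables rather than the real SDTrainer methods:

    import torch

    def bpp_loss(predict_with_lora, predict_without_lora, blank_embeds,
                 noisy_latents, timesteps, multiplier=1.0):
        # Target: the base model's behaviour on a blank prompt, LoRA bypassed (no grad needed).
        with torch.no_grad():
            prior_pred = predict_without_lora(noisy_latents, timesteps, blank_embeds)
        # Prediction: the LoRA-enabled model on the same blank prompt.
        pred = predict_with_lora(noisy_latents, timesteps, blank_embeds)
        return torch.nn.functional.mse_loss(pred, prior_pred) * multiplier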
@@ -135,6 +135,8 @@ export interface TrainConfig {
   diff_output_preservation: boolean;
   diff_output_preservation_multiplier: number;
   diff_output_preservation_class: string;
+  blank_prompt_preservation?: boolean;
+  blank_prompt_preservation_multiplier?: number;
   switch_boundary_every: number;
   loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
 }
@@ -1 +1 @@
-VERSION = "0.7.1"
+VERSION = "0.7.2"