mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added blank prompt preservation
This commit is contained in:
@@ -95,8 +95,13 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
raise ValueError("diff_output_preservation requires a network to be set")
|
raise ValueError("diff_output_preservation requires a network to be set")
|
||||||
if self.train_config.train_text_encoder:
|
if self.train_config.train_text_encoder:
|
||||||
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
|
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
|
||||||
|
|
||||||
# always do a prior prediction when doing diff output preservation
|
if self.train_config.blank_prompt_preservation:
|
||||||
|
if self.network_config is None:
|
||||||
|
raise ValueError("blank_prompt_preservation requires a network to be set")
|
||||||
|
|
||||||
|
if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
|
||||||
|
# always do a prior prediction when doing output preservation
|
||||||
self.do_prior_prediction = True
|
self.do_prior_prediction = True
|
||||||
|
|
||||||
# store the loss target for a batch so we can use it in a loss
|
# store the loss target for a batch so we can use it in a loss
|
||||||
@@ -343,6 +348,13 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.sd.text_encoder_to("cpu")
|
self.sd.text_encoder_to("cpu")
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
|
if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
|
||||||
|
# make sure we have this if not unloading
|
||||||
|
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
|
||||||
|
self.device_torch,
|
||||||
|
dtype=self.sd.torch_dtype
|
||||||
|
).detach()
|
||||||
|
|
||||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||||
vae = self.sd.vae
|
vae = self.sd.vae
|
||||||
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
||||||
@@ -1769,6 +1781,14 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.diff_output_preservation:
|
if self.train_config.diff_output_preservation:
|
||||||
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
||||||
|
|
||||||
|
if self.train_config.blank_prompt_preservation:
|
||||||
|
blank_embeds = self.cached_blank_embeds.clone().detach().to(
|
||||||
|
self.device_torch, dtype=dtype
|
||||||
|
)
|
||||||
|
prior_embeds_to_use = concat_prompt_embeds(
|
||||||
|
[blank_embeds] * noisy_latents.shape[0]
|
||||||
|
)
|
||||||
|
|
||||||
prior_pred = self.get_prior_prediction(
|
prior_pred = self.get_prior_prediction(
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
conditional_embeds=prior_embeds_to_use,
|
conditional_embeds=prior_embeds_to_use,
|
||||||
@@ -1944,7 +1964,8 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
prior_to_calculate_loss = prior_pred
|
prior_to_calculate_loss = prior_pred
|
||||||
# if we are doing diff_output_preservation and not noing inverted masked prior
|
# if we are doing diff_output_preservation and not noing inverted masked prior
|
||||||
# then we need to send none here so it will not target the prior
|
# then we need to send none here so it will not target the prior
|
||||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
|
||||||
|
if doing_preservation and not do_inverted_masked_prior:
|
||||||
prior_to_calculate_loss = None
|
prior_to_calculate_loss = None
|
||||||
|
|
||||||
loss = self.calculate_loss(
|
loss = self.calculate_loss(
|
||||||
@@ -1957,24 +1978,34 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
prior_pred=prior_to_calculate_loss,
|
prior_pred=prior_to_calculate_loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.train_config.diff_output_preservation:
|
if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
|
||||||
# send the loss backwards otherwise checkpointing will fail
|
# send the loss backwards otherwise checkpointing will fail
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss)
|
||||||
normal_loss = loss.detach() # dont send backward again
|
normal_loss = loss.detach() # dont send backward again
|
||||||
|
|
||||||
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
with torch.no_grad():
|
||||||
dop_pred = self.predict_noise(
|
if self.train_config.diff_output_preservation:
|
||||||
|
preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
||||||
|
elif self.train_config.blank_prompt_preservation:
|
||||||
|
blank_embeds = self.cached_blank_embeds.clone().detach().to(
|
||||||
|
self.device_torch, dtype=dtype
|
||||||
|
)
|
||||||
|
preservation_embeds = concat_prompt_embeds(
|
||||||
|
[blank_embeds] * noisy_latents.shape[0]
|
||||||
|
)
|
||||||
|
preservation_pred = self.predict_noise(
|
||||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
|
conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
|
||||||
unconditional_embeds=unconditional_embeds,
|
unconditional_embeds=unconditional_embeds,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
**pred_kwargs
|
**pred_kwargs
|
||||||
)
|
)
|
||||||
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
|
multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
|
||||||
self.accelerator.backward(dop_loss)
|
preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
|
||||||
|
self.accelerator.backward(preservation_loss)
|
||||||
loss = normal_loss + dop_loss
|
|
||||||
|
loss = normal_loss + preservation_loss
|
||||||
loss = loss.clone().detach()
|
loss = loss.clone().detach()
|
||||||
# require grad again so the backward wont fail
|
# require grad again so the backward wont fail
|
||||||
loss.requires_grad_(True)
|
loss.requires_grad_(True)
|
||||||
|
|||||||
@@ -451,7 +451,11 @@ class TrainConfig:
|
|||||||
self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
|
self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
|
||||||
# If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman"
|
# If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman"
|
||||||
self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
|
self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
|
||||||
|
|
||||||
|
# blank prompt preservation will preserve the model's knowledge of a blank prompt
|
||||||
|
self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False)
|
||||||
|
self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0)
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||||
self.match_adapter_chance = 1.0
|
self.match_adapter_chance = 1.0
|
||||||
@@ -1318,5 +1322,8 @@ def validate_configs(
|
|||||||
if model_config.arch == 'qwen_image_edit':
|
if model_config.arch == 'qwen_image_edit':
|
||||||
if train_config.unload_text_encoder:
|
if train_config.unload_text_encoder:
|
||||||
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
|
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
|
||||||
|
|
||||||
|
if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
|
||||||
|
raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -215,12 +215,12 @@ export default function SimpleJob({
|
|||||||
</FormGroup>
|
</FormGroup>
|
||||||
)}
|
)}
|
||||||
{modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
|
{modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
|
||||||
<Checkbox
|
<Checkbox
|
||||||
label="Match Target Res"
|
label="Match Target Res"
|
||||||
docKey="model.qie.match_target_res"
|
docKey="model.qie.match_target_res"
|
||||||
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
|
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
|
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
|
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
|
||||||
<>
|
<>
|
||||||
@@ -586,16 +586,27 @@ export default function SimpleJob({
|
|||||||
</FormGroup>
|
</FormGroup>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
|
{disableSections.includes('train.diff_output_preservation') ||
|
||||||
|
disableSections.includes('train.blank_prompt_preservation') ? null : (
|
||||||
|
<FormGroup label="Regularization">
|
||||||
|
<></>
|
||||||
|
</FormGroup>
|
||||||
|
)}
|
||||||
{disableSections.includes('train.diff_output_preservation') ? null : (
|
{disableSections.includes('train.diff_output_preservation') ? null : (
|
||||||
<>
|
<>
|
||||||
<FormGroup label="Regularization">
|
<Checkbox
|
||||||
<Checkbox
|
label="Differential Output Preservation"
|
||||||
label="Differential Output Preservation"
|
docKey={'train.diff_output_preservation'}
|
||||||
className="pt-1"
|
className="pt-1"
|
||||||
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
onChange={value => {
|
||||||
/>
|
setJobConfig(value, 'config.process[0].train.diff_output_preservation');
|
||||||
</FormGroup>
|
if (value && jobConfig.config.process[0].train.blank_prompt_preservation) {
|
||||||
|
// only one can be enabled at a time
|
||||||
|
setJobConfig(false, 'config.process[0].train.blank_prompt_preservation');
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
{jobConfig.config.process[0].train.diff_output_preservation && (
|
{jobConfig.config.process[0].train.diff_output_preservation && (
|
||||||
<>
|
<>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
@@ -610,7 +621,7 @@ export default function SimpleJob({
|
|||||||
/>
|
/>
|
||||||
<TextInput
|
<TextInput
|
||||||
label="DOP Preservation Class"
|
label="DOP Preservation Class"
|
||||||
className="pt-2"
|
className="pt-2 pb-4"
|
||||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||||
onChange={value =>
|
onChange={value =>
|
||||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
|
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
|
||||||
@@ -621,6 +632,39 @@ export default function SimpleJob({
|
|||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
{disableSections.includes('train.blank_prompt_preservation') ? null : (
|
||||||
|
<>
|
||||||
|
<Checkbox
|
||||||
|
label="Blank Prompt Preservation"
|
||||||
|
docKey={'train.blank_prompt_preservation'}
|
||||||
|
className="pt-1"
|
||||||
|
checked={jobConfig.config.process[0].train.blank_prompt_preservation || false}
|
||||||
|
onChange={value => {
|
||||||
|
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation');
|
||||||
|
if (value && jobConfig.config.process[0].train.diff_output_preservation) {
|
||||||
|
// only one can be enabled at a time
|
||||||
|
setJobConfig(false, 'config.process[0].train.diff_output_preservation');
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{jobConfig.config.process[0].train.blank_prompt_preservation && (
|
||||||
|
<>
|
||||||
|
<NumberInput
|
||||||
|
label="BPP Loss Multiplier"
|
||||||
|
className="pt-2"
|
||||||
|
value={
|
||||||
|
(jobConfig.config.process[0].train.blank_prompt_preservation_multiplier as number) || 1.0
|
||||||
|
}
|
||||||
|
onChange={value =>
|
||||||
|
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier')
|
||||||
|
}
|
||||||
|
placeholder="eg. 1.0"
|
||||||
|
min={0}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ type DisableableSections =
|
|||||||
| 'network.conv'
|
| 'network.conv'
|
||||||
| 'trigger_word'
|
| 'trigger_word'
|
||||||
| 'train.diff_output_preservation'
|
| 'train.diff_output_preservation'
|
||||||
|
| 'train.blank_prompt_preservation'
|
||||||
| 'train.unload_text_encoder'
|
| 'train.unload_text_encoder'
|
||||||
| 'slider';
|
| 'slider';
|
||||||
|
|
||||||
|
|||||||
@@ -228,6 +228,36 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
'train.diff_output_preservation': {
|
||||||
|
title: 'Differential Output Preservation',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
Differential Output Preservation (DOP) is a technique to help preserve class of the trained concept during
|
||||||
|
training. For this, you must have a trigger word set to differentiate your concept from its class. For instance,
|
||||||
|
You may be training a woman named Alice. Your trigger word may be "Alice". The class is "woman", since Alice is
|
||||||
|
a woman. We want to teach the model to remember what it knows about the class "woman" while teaching it what is
|
||||||
|
different about Alice. During training, the trainer will make a prediction with your LoRA bypassed and your
|
||||||
|
trigger word in the prompt replaced with the class word. Making "photo of Alice" become "photo of woman". This
|
||||||
|
prediction is called the prior prediction. Each step, we will do the normal training step, but also do another
|
||||||
|
step with this prior prediction and the class prompt in order to teach our LoRA to preserve the knowledge of the
|
||||||
|
class. This should not only improve the performance of your trained concept, but also allow you to do things
|
||||||
|
like "Alice standing next to a woman" and not make both of the people look like Alice.
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
'train.blank_prompt_preservation': {
|
||||||
|
title: 'Blank Prompt Preservation',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
Blank Prompt Preservation (BPP) is a technique to help preserve the current models knowledge when unprompted.
|
||||||
|
This will not only help the model become more flexible, but will also help the quality of your concept during
|
||||||
|
inference, especially when a model uses CFG (Classifier Free Guidance) on inference. At each step during
|
||||||
|
training, a prior prediction is made with a blank prompt and with the LoRA disabled. This prediction is then
|
||||||
|
used as a target on an additional training step with a blank prompt, to preserve the model's knowledge when no
|
||||||
|
prompt is given. This helps the model to not overfit to the prompt and retain its generalization capabilities.
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||||
|
|||||||
@@ -135,6 +135,8 @@ export interface TrainConfig {
|
|||||||
diff_output_preservation: boolean;
|
diff_output_preservation: boolean;
|
||||||
diff_output_preservation_multiplier: number;
|
diff_output_preservation_multiplier: number;
|
||||||
diff_output_preservation_class: string;
|
diff_output_preservation_class: string;
|
||||||
|
blank_prompt_preservation?: boolean;
|
||||||
|
blank_prompt_preservation_multiplier?: number;
|
||||||
switch_boundary_every: number;
|
switch_boundary_every: number;
|
||||||
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.7.1"
|
VERSION = "0.7.2"
|
||||||
Reference in New Issue
Block a user