mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added Differential Output Preservation Loss to trainer and ui
This commit is contained in:
@@ -78,8 +78,20 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
self.cached_blank_embeds: Optional[PromptEmbeds] = None
|
self.cached_blank_embeds: Optional[PromptEmbeds] = None
|
||||||
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
|
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
|
||||||
|
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
|
||||||
|
|
||||||
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
||||||
|
|
||||||
|
if self.train_config.diff_output_preservation:
|
||||||
|
if self.trigger_word is None:
|
||||||
|
raise ValueError("diff_output_preservation requires a trigger_word to be set")
|
||||||
|
if self.network_config is None:
|
||||||
|
raise ValueError("diff_output_preservation requires a network to be set")
|
||||||
|
if self.train_config.train_text_encoder:
|
||||||
|
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
|
||||||
|
|
||||||
|
# always do a prior prediction when doing diff output preservation
|
||||||
|
self.do_prior_prediction = True
|
||||||
|
|
||||||
|
|
||||||
def before_model_load(self):
|
def before_model_load(self):
|
||||||
@@ -176,6 +188,8 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.cached_blank_embeds = self.sd.encode_prompt("")
|
self.cached_blank_embeds = self.sd.encode_prompt("")
|
||||||
if self.trigger_word is not None:
|
if self.trigger_word is not None:
|
||||||
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
|
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
|
||||||
|
if self.train_config.diff_output_preservation:
|
||||||
|
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
|
||||||
|
|
||||||
# move back to cpu
|
# move back to cpu
|
||||||
self.sd.text_encoder_to('cpu')
|
self.sd.text_encoder_to('cpu')
|
||||||
@@ -536,6 +550,19 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
|
|
||||||
return loss + additional_loss
|
return loss + additional_loss
|
||||||
|
|
||||||
|
def get_diff_output_preservation_loss(
|
||||||
|
self,
|
||||||
|
noise_pred: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
mask_multiplier: Union[torch.Tensor, float] = 1.0,
|
||||||
|
prior_pred: Union[torch.Tensor, None] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
loss_target = self.train_config.loss_target
|
||||||
|
|
||||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||||
return batch
|
return batch
|
||||||
@@ -872,8 +899,8 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
was_adapter_active = self.adapter.is_active
|
was_adapter_active = self.adapter.is_active
|
||||||
self.adapter.is_active = False
|
self.adapter.is_active = False
|
||||||
|
|
||||||
if self.train_config.unload_text_encoder:
|
if self.train_config.unload_text_encoder and self.adapter is not None:
|
||||||
raise ValueError("Prior predictions currently do not support unloading text encoder")
|
raise ValueError("Prior predictions currently do not support unloading text encoder with adapter")
|
||||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
@@ -1336,7 +1363,16 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
if isinstance(self.adapter, CustomAdapter):
|
if isinstance(self.adapter, CustomAdapter):
|
||||||
self.adapter.is_unconditional_run = False
|
self.adapter.is_unconditional_run = False
|
||||||
|
|
||||||
|
if self.train_config.diff_output_preservation:
|
||||||
|
dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts]
|
||||||
|
dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2]
|
||||||
|
self.diff_output_preservation_embeds = self.sd.encode_prompt(
|
||||||
|
dop_prompts, dop_prompts_2,
|
||||||
|
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||||
|
long_prompts=self.do_long_prompts).to(
|
||||||
|
self.device_torch,
|
||||||
|
dtype=dtype)
|
||||||
# detach the embeddings
|
# detach the embeddings
|
||||||
conditional_embeds = conditional_embeds.detach()
|
conditional_embeds = conditional_embeds.detach()
|
||||||
if self.train_config.do_cfg:
|
if self.train_config.do_cfg:
|
||||||
@@ -1524,9 +1560,14 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if ((
|
if ((
|
||||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
|
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
|
||||||
with self.timer('prior predict'):
|
with self.timer('prior predict'):
|
||||||
|
prior_embeds_to_use = conditional_embeds
|
||||||
|
# use diff_output_preservation embeds if doing dfe
|
||||||
|
if self.train_config.diff_output_preservation:
|
||||||
|
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
||||||
|
|
||||||
prior_pred = self.get_prior_prediction(
|
prior_pred = self.get_prior_prediction(
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
conditional_embeds=conditional_embeds,
|
conditional_embeds=prior_embeds_to_use,
|
||||||
match_adapter_assist=match_adapter_assist,
|
match_adapter_assist=match_adapter_assist,
|
||||||
network_weight_list=network_weight_list,
|
network_weight_list=network_weight_list,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
@@ -1627,6 +1668,12 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
with self.timer('calculate_loss'):
|
with self.timer('calculate_loss'):
|
||||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
prior_to_calculate_loss = prior_pred
|
||||||
|
# if we are doing diff_output_preservation and not noing inverted masked prior
|
||||||
|
# then we need to send none here so it will not target the prior
|
||||||
|
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||||
|
prior_to_calculate_loss = None
|
||||||
|
|
||||||
loss = self.calculate_loss(
|
loss = self.calculate_loss(
|
||||||
noise_pred=noise_pred,
|
noise_pred=noise_pred,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
@@ -1634,8 +1681,30 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
mask_multiplier=mask_multiplier,
|
mask_multiplier=mask_multiplier,
|
||||||
prior_pred=prior_pred,
|
prior_pred=prior_to_calculate_loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.train_config.diff_output_preservation:
|
||||||
|
# send the loss backwards otherwise checkpointing will fail
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
normal_loss = loss.detach() # dont send backward again
|
||||||
|
|
||||||
|
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
|
||||||
|
dop_pred = self.predict_noise(
|
||||||
|
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
|
timesteps=timesteps,
|
||||||
|
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
|
||||||
|
unconditional_embeds=unconditional_embeds,
|
||||||
|
**pred_kwargs
|
||||||
|
)
|
||||||
|
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
|
||||||
|
self.accelerator.backward(dop_loss)
|
||||||
|
|
||||||
|
loss = normal_loss + dop_loss
|
||||||
|
loss = loss.clone().detach()
|
||||||
|
# require grad again so the backward wont fail
|
||||||
|
loss.requires_grad_(True)
|
||||||
|
|
||||||
# check if nan
|
# check if nan
|
||||||
if torch.isnan(loss):
|
if torch.isnan(loss):
|
||||||
print_acc("loss is nan")
|
print_acc("loss is nan")
|
||||||
|
|||||||
@@ -341,6 +341,12 @@ class TrainConfig:
|
|||||||
# unmasked reign. It is unmasked regularization basically
|
# unmasked reign. It is unmasked regularization basically
|
||||||
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
|
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
|
||||||
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
|
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
|
||||||
|
|
||||||
|
# DOP will will run the same image and prompt through the network without the trigger word blank and use it as a target
|
||||||
|
self.diff_output_preservation = kwargs.get('diff_output_preservation', False)
|
||||||
|
self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
|
||||||
|
# If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman"
|
||||||
|
self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||||
|
|||||||
@@ -62,6 +62,20 @@ class PromptEmbeds:
|
|||||||
prompt_embeds.attention_mask = self.attention_mask.clone()
|
prompt_embeds.attention_mask = self.attention_mask.clone()
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
def expand_to_batch(self, batch_size):
|
||||||
|
pe = self.clone()
|
||||||
|
current_batch_size = pe.text_embeds.shape[0]
|
||||||
|
if current_batch_size == batch_size:
|
||||||
|
return pe
|
||||||
|
if current_batch_size != 1:
|
||||||
|
raise Exception("Can only expand batch size for batch size 1")
|
||||||
|
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
|
||||||
|
if pe.pooled_embeds is not None:
|
||||||
|
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
|
||||||
|
if pe.attention_mask is not None:
|
||||||
|
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
|
||||||
class EncodedPromptPair:
|
class EncodedPromptPair:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -61,6 +61,10 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
ema_decay: 0.99,
|
ema_decay: 0.99,
|
||||||
},
|
},
|
||||||
dtype: 'bf16',
|
dtype: 'bf16',
|
||||||
|
diff_output_preservation: false,
|
||||||
|
diff_output_preservation_multiplier: 1.0,
|
||||||
|
diff_output_preservation_class: 'person'
|
||||||
|
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
name_or_path: 'ostris/Flex.1-alpha',
|
name_or_path: 'ostris/Flex.1-alpha',
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ export default function TrainingForm() {
|
|||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<Card title="Training Configuration">
|
<Card title="Training Configuration">
|
||||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
|
||||||
<div>
|
<div>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
label="Batch Size"
|
label="Batch Size"
|
||||||
@@ -384,6 +384,31 @@ export default function TrainingForm() {
|
|||||||
min={0}
|
min={0}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
<div>
|
||||||
|
<FormGroup label="Regularization">
|
||||||
|
<Checkbox
|
||||||
|
label="Differtial Output Preservation"
|
||||||
|
className="pt-1"
|
||||||
|
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||||
|
/>
|
||||||
|
</FormGroup>
|
||||||
|
<NumberInput
|
||||||
|
label="DFE Loss Multiplier"
|
||||||
|
className="pt-2"
|
||||||
|
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
||||||
|
placeholder="eg. 1.0"
|
||||||
|
min={0}
|
||||||
|
/>
|
||||||
|
<TextInput
|
||||||
|
label="DFE Preservation Class"
|
||||||
|
className="pt-2"
|
||||||
|
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||||
|
placeholder="eg. woman"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -100,6 +100,9 @@ export interface TrainConfig {
|
|||||||
optimizer_params: {
|
optimizer_params: {
|
||||||
weight_decay: number;
|
weight_decay: number;
|
||||||
};
|
};
|
||||||
|
diff_output_preservation: boolean;
|
||||||
|
diff_output_preservation_multiplier: number;
|
||||||
|
diff_output_preservation_class: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface QuantizeKwargsConfig {
|
export interface QuantizeKwargsConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user