From f6e16e582afa9dca89bf862ed793610c01095d86 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 25 Feb 2025 20:12:36 -0700
Subject: [PATCH] Added Differential Output Preservation Loss to trainer and ui

---
 extensions_built_in/sd_trainer/SDTrainer.py | 79 +++++++++++++++++++--
 toolkit/config_modules.py                   |  6 ++
 toolkit/prompt_utils.py                     | 14 ++++
 ui/src/app/jobs/new/jobConfig.ts            |  4 ++
 ui/src/app/jobs/new/page.tsx                | 27 ++++++++++++++++++++++++++-
 ui/src/types.ts                             |  3 +
 6 files changed, 127 insertions(+), 6 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 044b7037..5e91c344 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -78,8 +78,20 @@ class SDTrainer(BaseSDTrainProcess):
         self.cached_blank_embeds: Optional[PromptEmbeds] = None
         self.cached_trigger_embeds: Optional[PromptEmbeds] = None
+        self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
         self.dfe: Optional[DiffusionFeatureExtractor] = None
+
+        if self.train_config.diff_output_preservation:
+            if self.trigger_word is None:
+                raise ValueError("diff_output_preservation requires a trigger_word to be set")
+            if self.network_config is None:
+                raise ValueError("diff_output_preservation requires a network to be set")
+            if self.train_config.train_text_encoder:
+                raise ValueError("diff_output_preservation is not supported with train_text_encoder")
+
+            # always do a prior prediction when doing diff output preservation
+            self.do_prior_prediction = True
 
     def before_model_load(self):
@@ -176,6 +188,8 @@ class SDTrainer(BaseSDTrainProcess):
             self.cached_blank_embeds = self.sd.encode_prompt("")
             if self.trigger_word is not None:
                 self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
+            if self.train_config.diff_output_preservation:
+                self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
 
             # move back to cpu
             self.sd.text_encoder_to('cpu')
@@ -536,6 +550,19 @@ class SDTrainer(BaseSDTrainProcess):
 
         return loss + additional_loss
 
+
+    def get_diff_output_preservation_loss(
+            self,
+            noise_pred: torch.Tensor,
+            noise: torch.Tensor,
+            noisy_latents: torch.Tensor,
+            timesteps: torch.Tensor,
+            batch: 'DataLoaderBatchDTO',
+            mask_multiplier: Union[torch.Tensor, float] = 1.0,
+            prior_pred: Union[torch.Tensor, None] = None,
+            **kwargs
+    ):
+        return torch.nn.functional.mse_loss(noise_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier  # mirrors the inline DOP loss in the train step
 
     def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
         return batch
@@ -872,8 +899,8 @@ class SDTrainer(BaseSDTrainProcess):
             was_adapter_active = self.adapter.is_active
             self.adapter.is_active = False
 
-        if self.train_config.unload_text_encoder:
-            raise ValueError("Prior predictions currently do not support unloading text encoder")
+        if self.train_config.unload_text_encoder and self.adapter is not None:
+            raise ValueError("Prior predictions currently do not support unloading text encoder with adapter")
         # do a prediction here so we can match its output with network multiplier set to 0.0
         with torch.no_grad():
             dtype = get_torch_dtype(self.train_config.dtype)
@@ -1336,7 +1363,16 @@ class SDTrainer(BaseSDTrainProcess):
                     dtype=dtype)
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = False
-
+
+                if self.train_config.diff_output_preservation:
+                    dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts]
+                    dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] if prompt_2 is not None else None
+                    self.diff_output_preservation_embeds = self.sd.encode_prompt(
+                        dop_prompts, dop_prompts_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=self.do_long_prompts).to(
+                        self.device_torch,
+                        dtype=dtype)
             # detach the embeddings
             conditional_embeds = conditional_embeds.detach()
             if self.train_config.do_cfg:
@@ -1524,9 +1560,14 @@ class SDTrainer(BaseSDTrainProcess):
         if ((
                 has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
             with self.timer('prior predict'):
+                prior_embeds_to_use = conditional_embeds
+                # use the diff_output_preservation embeds if doing DOP
+                if self.train_config.diff_output_preservation:
+                    prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+
                 prior_pred = self.get_prior_prediction(
                     noisy_latents=noisy_latents,
-                    conditional_embeds=conditional_embeds,
+                    conditional_embeds=prior_embeds_to_use,
                     match_adapter_assist=match_adapter_assist,
                     network_weight_list=network_weight_list,
                     timesteps=timesteps,
@@ -1627,6 +1668,12 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('calculate_loss'):
                 noise = noise.to(self.device_torch, dtype=dtype).detach()
 
+                prior_to_calculate_loss = prior_pred
+                # if we are doing diff_output_preservation and not doing inverted masked prior,
+                # send None here so the loss will not target the prior
+                if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+                    prior_to_calculate_loss = None
+
                 loss = self.calculate_loss(
                     noise_pred=noise_pred,
                     noise=noise,
@@ -1634,8 +1681,30 @@ class SDTrainer(BaseSDTrainProcess):
                     timesteps=timesteps,
                     batch=batch,
                     mask_multiplier=mask_multiplier,
-                    prior_pred=prior_pred,
+                    prior_pred=prior_to_calculate_loss,
                 )
+
+                if self.train_config.diff_output_preservation:
+                    # send the loss backward now, otherwise checkpointing will fail
+                    self.accelerator.backward(loss)
+                    normal_loss = loss.detach()  # don't send it backward again
+
+                    dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                    dop_pred = self.predict_noise(
+                        noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                        timesteps=timesteps,
+                        conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                        unconditional_embeds=unconditional_embeds,
+                        **pred_kwargs
+                    )
+                    dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
+                    self.accelerator.backward(dop_loss)
+
+                    loss = normal_loss + dop_loss
+                    loss = loss.clone().detach()
+                    # require grad again so the later backward won't fail
+                    loss.requires_grad_(True)
+
                 # check if nan
                 if torch.isnan(loss):
                     print_acc("loss is nan")
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ee4b6e72..3fc7728c 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -341,6 +341,12 @@ class TrainConfig:
         # unmasked region. It is basically unmasked regularization
         self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
         self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
+
+        # DOP runs the same image and prompt through the model with the trigger word replaced by the class name, and uses that prediction as a target
+        self.diff_output_preservation = kwargs.get('diff_output_preservation', False)
+        self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
+        # If the trigger word is in the prompt, we will use this class name to replace it, eg. "sks woman" -> "woman"
+        self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
 
         # legacy
         if match_adapter_assist and self.match_adapter_chance == 0.0:
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index a145841c..52e15907 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -62,6 +62,20 @@ class PromptEmbeds:
             prompt_embeds.attention_mask = self.attention_mask.clone()
         return prompt_embeds
 
+    def expand_to_batch(self, batch_size):
+        pe = self.clone()
+        current_batch_size = pe.text_embeds.shape[0]
+        if current_batch_size == batch_size:
+            return pe
+        if current_batch_size != 1:
+            raise Exception("Can only expand batch size for batch size 1")
+        pe.text_embeds = pe.text_embeds.expand(batch_size, *([-1] * (pe.text_embeds.dim() - 1)))  # expand only the batch dim; text_embeds is usually (1, seq_len, dim)
+        if pe.pooled_embeds is not None:
+            pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
+        if pe.attention_mask is not None:
+            pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
+        return pe
+
 
 class EncodedPromptPair:
     def __init__(
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index 91b65540..d87d12c4 100644
--- a/ui/src/app/jobs/new/jobConfig.ts
+++ b/ui/src/app/jobs/new/jobConfig.ts
@@ -61,6 +61,10 @@ export const defaultJobConfig: JobConfig = {
       ema_decay: 0.99,
     },
     dtype: 'bf16',
+    diff_output_preservation: false,
+    diff_output_preservation_multiplier: 1.0,
+    diff_output_preservation_class: 'person',
+
   },
   model: {
     name_or_path: 'ostris/Flex.1-alpha',
diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx
index 7eb1180b..e3aec77f 100644
--- a/ui/src/app/jobs/new/page.tsx
+++ b/ui/src/app/jobs/new/page.tsx
@@ -275,7 +275,32 @@ export default function TrainingForm() {
               />
             </div>
           </div>
-          </div>
+          <FormGroup label="Differential Output Preservation" className="pt-2">
+            <Checkbox
+              label="Enabled"
+              checked={jobConfig.config.process[0].train.diff_output_preservation ?? false}
+              onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
+            />
+          </FormGroup>
+          {jobConfig.config.process[0].train.diff_output_preservation && (
+            <>
+              <NumberInput
+                label="DOP Loss Multiplier"
+                value={jobConfig.config.process[0].train.diff_output_preservation_multiplier}
+                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
+                placeholder="eg. 1.0"
+                min={0}
+                required
+              />
+              <TextInput
+                label="DOP Preservation Class"
+                value={jobConfig.config.process[0].train.diff_output_preservation_class}
+                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
+                placeholder="eg. woman"
+              />
+            </>
+          )}
+          </div>
         </Card>
       </div>
     </div>
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 42c737cd..16ecb220 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -100,6 +100,9 @@ export interface TrainConfig {
   optimizer_params: {
     weight_decay: number;
   };
+  diff_output_preservation: boolean;
+  diff_output_preservation_multiplier: number;
+  diff_output_preservation_class: string;
 }
 
 export interface QuantizeKwargsConfig {
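
Note for reviewers: pieced together from the hunks above, the Differential Output Preservation (DOP) flow is (1) encode a "class" prompt with the trigger word swapped for diff_output_preservation_class, (2) use those embeds for the prior prediction with the network at weight 0, (3) take the normal loss on the trigger prompt with the network active, then (4) add an MSE term pulling the network-on prediction for the class prompt back toward the prior. Below is a minimal, self-contained sketch of that flow; every name in it (model, network_active, the tensor arguments) is a hypothetical stand-in, not this trainer's actual API.

    import torch
    import torch.nn.functional as F

    def train_step_with_dop(model, noisy_latents, timesteps, noise,
                            trigger_embeds, class_embeds, dop_multiplier=1.0):
        # 1) Prior prediction: class prompt ("woman"), network disabled.
        #    This is the base-model behavior we want to preserve.
        with torch.no_grad():
            prior_pred = model(noisy_latents, timesteps, class_embeds,
                               network_active=False)

        # 2) Normal loss: trigger prompt ("sks woman"), network enabled.
        noise_pred = model(noisy_latents, timesteps, trigger_embeds,
                           network_active=True)
        normal_loss = F.mse_loss(noise_pred, noise)
        normal_loss.backward()  # backward now so checkpointing does not fail

        # 3) DOP loss: with the network enabled, the bare class prompt should
        #    still reproduce the base model's prior prediction.
        dop_pred = model(noisy_latents, timesteps, class_embeds,
                         network_active=True)
        dop_loss = F.mse_loss(dop_pred, prior_pred) * dop_multiplier
        dop_loss.backward()

        # combined value is for logging only; both parts were already
        # sent backward separately, mirroring the patch
        return (normal_loss + dop_loss).detach()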
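
For completeness, the knobs this patch adds as they would be passed to TrainConfig (the keys are the ones TrainConfig reads above; the dict framing and the values are only an example):

    # DOP also requires a trigger_word and a network (LoRA) on the job,
    # and is rejected when train_text_encoder is enabled (see SDTrainer).
    train_kwargs = {
        "diff_output_preservation": True,
        # weight of the preservation term relative to the normal loss
        "diff_output_preservation_multiplier": 1.0,
        # replaces the trigger word in prompts, eg. "sks woman" -> "woman"
        "diff_output_preservation_class": "woman",
    }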