diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 5e562112..e4468575 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -454,6 +454,13 @@ class SDTrainer(BaseSDTrainProcess): loss = loss + self.adapter.additional_loss.mean() self.adapter.additional_loss = None + if self.train_config.target_norm_std: + # separate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(1.0 - pred_std).mean() + loss = loss + norm_std_loss + + return loss def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f6324364..af95e9b7 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -352,6 +352,9 @@ class TrainConfig: self.ema_config: EMAConfig = EMAConfig(**ema_config) + # adds an additional loss to the network to encourage it to output a normalized standard deviation + self.target_norm_std = kwargs.get('target_norm_std', None) + class ModelConfig: def __init__(self, **kwargs): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index ea81ccdb..a728c111 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -27,7 +27,6 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod from toolkit import train_tools from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.metadata import get_meta_for_safetensors -from toolkit.models.auraflow import patch_auraflow_pos_embed from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 7d492441..6a71e558 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ 
-751,7 +751,6 @@ def encode_prompts_auraflow( padding="max_length", return_tensors="pt", ) - text_inputs = {k: v.to(device) for k, v in text_inputs.items()} text_input_ids = text_inputs["input_ids"] untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids @@ -760,6 +759,7 @@ def encode_prompts_auraflow( ): removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} prompt_embeds = text_encoder(**text_inputs)[0] prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) prompt_embeds = prompt_embeds * prompt_attention_mask