Added target_norm_std which is a game changer

This commit is contained in:
Jaret Burkett
2024-07-28 16:08:33 -06:00
parent 0bc4d555c7
commit e81e19fd0f
4 changed files with 11 additions and 2 deletions

View File

@@ -454,6 +454,13 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss + self.adapter.additional_loss.mean()
self.adapter.additional_loss = None
if self.train_config.target_norm_std:
# separate out the batch and channels
pred_std = noise_pred.std([2, 3], keepdim=True)
norm_std_loss = torch.abs(1.0 - pred_std).mean()
loss = loss + norm_std_loss
return loss
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):

View File

@@ -352,6 +352,9 @@ class TrainConfig:
self.ema_config: EMAConfig = EMAConfig(**ema_config)
# adds an additional loss to the network to encourage it to output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
class ModelConfig:
def __init__(self, **kwargs):

View File

@@ -27,7 +27,6 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.auraflow import patch_auraflow_pos_embed
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter

View File

@@ -751,7 +751,6 @@ def encode_prompts_auraflow(
padding="max_length",
return_tensors="pt",
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
text_input_ids = text_inputs["input_ids"]
untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
@@ -760,6 +759,7 @@ def encode_prompts_auraflow(
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
prompt_embeds = text_encoder(**text_inputs)[0]
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * prompt_attention_mask