Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-02 17:19:48 +00:00
Added target_norm_std which is a game changer
@@ -454,6 +454,13 @@ class SDTrainer(BaseSDTrainProcess):
            loss = loss + self.adapter.additional_loss.mean()
            self.adapter.additional_loss = None

        if self.train_config.target_norm_std:
            # separate out the batch and channels
            pred_std = noise_pred.std([2, 3], keepdim=True)
            norm_std_loss = torch.abs(1.0 - pred_std).mean()
            loss = loss + norm_std_loss

        return loss

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
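For context, the diffusion target here (noise, or a target derived from it) has roughly unit standard deviation per channel, so the new term penalizes the prediction's spatial std for drifting away from 1.0. Below is a minimal, self-contained sketch of that penalty on dummy tensors; the shapes, the mis-scaled prediction, and the MSE base loss are stand-ins, and only the two std-penalty lines mirror the diff.

import torch
import torch.nn.functional as F

# Dummy stand-ins for the trainer's tensors (shapes are illustrative).
batch, channels, height, width = 2, 4, 64, 64
noise = torch.randn(batch, channels, height, width)              # unit-std target
noise_pred = 0.7 * torch.randn(batch, channels, height, width)   # mis-scaled prediction

base_loss = F.mse_loss(noise_pred, noise)

# std over the spatial dims only, keeping batch and channel separate (as in the diff)
pred_std = noise_pred.std([2, 3], keepdim=True)                  # shape [B, C, 1, 1]
norm_std_loss = torch.abs(1.0 - pred_std).mean()                 # distance from unit std

loss = base_loss + norm_std_loss
print(f"base={base_loss.item():.4f}  std_penalty={norm_std_loss.item():.4f}  total={loss.item():.4f}")

Note that, as in the diff, the penalty is added unweighted, so its contribution relative to the main loss depends only on how far the prediction's std drifts from 1.0.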
@@ -352,6 +352,9 @@ class TrainConfig:
        self.ema_config: EMAConfig = EMAConfig(**ema_config)

        # adds an additional loss to the network to encourage it to output a normalized standard deviation
        self.target_norm_std = kwargs.get('target_norm_std', None)


class ModelConfig:
    def __init__(self, **kwargs):
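Since TrainConfig is built from keyword arguments, enabling the new loss should just be a matter of including target_norm_std in the training section of the run config. A rough sketch of how the value flows from a config dict into the flag the trainer checks; the dict layout and class name here are illustrative, and only the kwargs.get line mirrors the diff.

# Illustrative only: the option defaults to None (disabled) and is gated on truthiness.
train_section = {
    'target_norm_std': True,  # hypothetical config entry enabling the extra loss
}

class TrainConfigSketch:
    def __init__(self, **kwargs):
        # mirrors the diff: missing key -> None -> the extra loss is skipped
        self.target_norm_std = kwargs.get('target_norm_std', None)

cfg = TrainConfigSketch(**train_section)
if cfg.target_norm_std:
    print('std-normalization loss enabled')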
@@ -27,7 +27,6 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.auraflow import patch_auraflow_pos_embed
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
@@ -751,7 +751,6 @@ def encode_prompts_auraflow(
        padding="max_length",
        return_tensors="pt",
    )
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    text_input_ids = text_inputs["input_ids"]
    untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
@@ -760,6 +759,7 @@ def encode_prompts_auraflow(
    ):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])

    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    prompt_embeds = text_encoder(**text_inputs)[0]
    prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * prompt_attention_mask
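The remaining change is unrelated to the new loss: in encode_prompts_auraflow the `.to(device)` mapping of text_inputs appears to be moved from right after tokenization to just before the encoder call, after the truncation check, and the embeddings are then zeroed at padded positions via the attention mask. A rough sketch of that encode-and-mask flow with dummy tensors; the tokenizer and text encoder here are stand-ins, not the real AuraFlow components.

import torch

# Dummy stand-ins for the prompt-encoding flow (the real code uses the AuraFlow
# tokenizer and text encoder).
device = 'cpu'
batch, seq_len, hidden = 2, 8, 16

text_inputs = {
    'input_ids': torch.randint(0, 1000, (batch, seq_len)),
    'attention_mask': torch.tensor([[1] * 5 + [0] * 3, [1] * 8]),  # padded vs. full prompt
}

# as in the patched code: move the inputs to the target device just before encoding
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

# stand-in for text_encoder(**text_inputs)[0]
prompt_embeds = torch.randn(batch, seq_len, hidden)

# zero the embeddings at padded positions, exactly as the diff does
prompt_attention_mask = text_inputs['attention_mask'].unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * prompt_attention_mask

print(prompt_embeds[0, 5:].abs().sum())  # padded positions are zeroed out -> tensor(0.)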