added a method to apply multipliers to noise and latents prior to combining

commit da6302ada8
parent a05459afaf
Author: Jaret Burkett
Date: 2023-10-17 06:09:16 -06:00
2 changed files with 22 additions and 10 deletions


@@ -5,6 +5,7 @@ from collections import OrderedDict
import os
from typing import Union, List
import numpy as np
from diffusers import T2IAdapter
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
@@ -538,6 +539,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noise_multiplier = self.train_config.noise_multiplier
noise = noise * noise_multiplier
img_multiplier = self.train_config.img_multiplier
latents = latents * img_multiplier
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
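For context, a minimal standalone sketch of what this hunk does: both tensors are scaled by their configured multipliers before the scheduler combines them (the DDPMScheduler instance, tensor shapes, and variable values here are assumptions for illustration, not the project's exact setup; the 1.0 defaults come from the config hunk below).

import torch
from diffusers import DDPMScheduler

# Standalone sketch of the new multiplier step; scheduler choice and shapes are assumed.
scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(1, 4, 64, 64)        # stand-in for VAE-encoded image latents
noise = torch.randn_like(latents)          # freshly sampled noise
timesteps = torch.randint(0, 1000, (1,))

noise_multiplier = 1.0                     # train_config.noise_multiplier (defaults to 1.0)
img_multiplier = 1.0                       # train_config.img_multiplier (defaults to 1.0)

# Scale each tensor before they are combined; with the 1.0 defaults this is a no-op.
noise = noise * noise_multiplier
latents = latents * img_multiplier

# The scheduler then mixes them into the noisy latents used for the denoising loss.
noisy_latents = scheduler.add_noise(latents, noise, timesteps)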
@@ -997,10 +1006,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.set_postfix_str(prog_bar_string)
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
with self.timer('batch_cleanup'):
batch.cleanup()
# don't do on first step
if self.step_num != self.start_step:
if is_sample_step:
self.progress_bar.pause()
flush()
# print above the progress bar
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
@@ -1036,16 +1056,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# end of step
self.step_num = step
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
with self.timer('batch_cleanup'):
batch.cleanup()
# flush every 10 steps
# if self.step_num % 10 == 0:
# flush()


@@ -122,6 +122,8 @@ class TrainConfig:
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
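A minimal sketch of how these two new options could travel from keyword arguments into the scaling step shown in the first hunk (the MiniTrainConfig class and scale_before_combining helper are simplified, hypothetical stand-ins, not the project's actual TrainConfig API):

import torch


class MiniTrainConfig:
    # Simplified stand-in for TrainConfig: only the two new options are modeled.
    def __init__(self, **kwargs):
        self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
        self.img_multiplier = kwargs.get('img_multiplier', 1.0)


def scale_before_combining(latents, noise, config):
    # Apply the configured multipliers before the scheduler's add_noise call.
    return latents * config.img_multiplier, noise * config.noise_multiplier


config = MiniTrainConfig(noise_multiplier=1.1, img_multiplier=0.9)
latents, noise = scale_before_combining(
    torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64), config
)

Leaving both values at their 1.0 defaults reproduces the previous behavior, so existing configs are unaffected.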