mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
added a method to apply multipliers to noise and latents prior to combining
This commit is contained in:
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
from diffusers import T2IAdapter
|
||||
# from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -538,6 +539,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
noise_multiplier = self.train_config.noise_multiplier
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
|
||||
img_multiplier = self.train_config.img_multiplier
|
||||
|
||||
latents = latents * img_multiplier
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# remove grads for these
|
||||
@@ -997,10 +1006,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
self.progress_bar.set_postfix_str(prog_bar_string)
|
||||
|
||||
# apply network normalizer if we are using it, not on regularization steps
|
||||
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
||||
with self.timer('apply_normalizer'):
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||
if isinstance(batch, DataLoaderBatchDTO):
|
||||
with self.timer('batch_cleanup'):
|
||||
batch.cleanup()
|
||||
|
||||
# don't do on first step
|
||||
if self.step_num != self.start_step:
|
||||
if is_sample_step:
|
||||
self.progress_bar.pause()
|
||||
flush()
|
||||
# print above the progress bar
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
@@ -1036,16 +1056,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# end of step
|
||||
self.step_num = step
|
||||
|
||||
# apply network normalizer if we are using it, not on regularization steps
|
||||
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
||||
with self.timer('apply_normalizer'):
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||
if isinstance(batch, DataLoaderBatchDTO):
|
||||
with self.timer('batch_cleanup'):
|
||||
batch.cleanup()
|
||||
|
||||
# flush every 10 steps
|
||||
# if self.step_num % 10 == 0:
|
||||
# flush()
|
||||
|
||||
@@ -122,6 +122,8 @@ class TrainConfig:
|
||||
self.start_step = kwargs.get('start_step', None)
|
||||
self.free_u = kwargs.get('free_u', False)
|
||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
|
||||
Reference in New Issue
Block a user